Gabriel Bibbó committed on
Commit dac6057 · 1 Parent(s): e78c137

adjust app.py

Files changed (1)
  1. app.py +54 -63
app.py CHANGED
@@ -522,7 +522,12 @@ class OptimizedAST:
                 self.model = self.model.half()
                 print(f"✅ {self.model_name} loaded with FP16 optimization")
             else:
-                print(f"✅ {self.model_name} loaded successfully")
+                # Apply quantization for CPU acceleration
+                import torch.nn as nn
+                self.model = torch.quantization.quantize_dynamic(
+                    self.model, {nn.Linear}, dtype=torch.qint8
+                )
+                print(f"✅ {self.model_name} loaded with CPU quantization")

             self.model.eval()
         else:
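For reference, the added branch uses PyTorch's dynamic quantization API. Below is a minimal sketch of the same pattern, with a toy two-layer network standing in for the real AST model (the layer sizes and input shape are illustrative assumptions, not taken from app.py):

import torch
import torch.nn as nn

# Toy stand-in for the real model; only the listed module types (nn.Linear) are quantized.
model = nn.Sequential(nn.Linear(768, 256), nn.ReLU(), nn.Linear(256, 2)).eval()

# Dynamic quantization stores int8 weights and quantizes activations on the fly;
# it only accelerates CPU inference, which is why the FP16 path above is kept for GPU.
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

with torch.no_grad():
    out = qmodel(torch.randn(1, 768))
print(out.shape)  # torch.Size([1, 2])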
@@ -665,16 +670,16 @@ class AudioProcessor:
             "WebRTC-VAD": 0.03,  # 30ms frames (480 samples)
             "E-PANNs": 1.0,      # CHANGED from 6.0 to 1.0 for better temporal resolution
             "PANNs": 1.0,        # CHANGED from 10.0 to 1.0 for better temporal resolution
-            "AST": 1.0           # 1 second for better temporal resolution
+            "AST": 0.96          # OPTIMIZED: Natural window size for AST
         }

-        # Model-specific hop sizes for efficiency - INCREASED to 20Hz
+        # Model-specific hop sizes for efficiency - OPTIMIZED for performance
         self.model_hop_sizes = {
             "Silero-VAD": 0.016,  # 16ms hop for Silero (512 samples window)
             "WebRTC-VAD": 0.03,   # 30ms hop for WebRTC (match frame duration)
             "E-PANNs": 0.05,      # CHANGED from 0.1 to 0.05 for 20Hz
             "PANNs": 0.05,        # CHANGED from 0.1 to 0.05 for 20Hz
-            "AST": 0.05           # CHANGED from 0.1 to 0.05 for 20Hz
+            "AST": 0.24           # OPTIMIZED: Reduced frequency (4.17 Hz) for performance
         }

         # Model-specific thresholds for better detection
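The hop size sets how often each model is re-evaluated: a 0.05 s hop corresponds to 1/0.05 = 20 Hz, and the new 0.24 s hop for AST to 1/0.24 ≈ 4.17 Hz, as the comments state. A small sketch of that arithmetic, using the hop values from the hunk above:

model_hop_sizes = {"Silero-VAD": 0.016, "WebRTC-VAD": 0.03,
                   "E-PANNs": 0.05, "PANNs": 0.05, "AST": 0.24}

for name, hop in model_hop_sizes.items():
    rate_hz = 1.0 / hop                # model evaluations per second
    calls_per_10s = round(10.0 / hop)  # forward passes needed for a 10-second clip
    print(f"{name:10s} hop={hop:.3f}s -> {rate_hz:5.2f} Hz, ~{calls_per_10s} calls per 10 s")

# AST drops from 200 calls per 10 s (0.05 s hop, 20 Hz) to ~42 (0.24 s hop, ~4.17 Hz),
# roughly a 5x reduction in forward passes in exchange for coarser temporal resolution.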
@@ -724,7 +729,7 @@ class AudioProcessor:
             hop_length=self.hop_length,
             win_length=self.n_fft,
             window='hann',
-            center=True   # CHANGE 2: True to align with centred timestamps
+            center=False  # CHANGE: False for real-time, no padding
         )

         power_spec = np.abs(stft) ** 2
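Assuming the STFT call here is librosa.stft (the keyword arguments hop_length, win_length, window and center match that signature), the center flag decides whether frame k is zero-padded so it is centred at k·hop/sr, or simply starts at k·hop/sr with no padding and no look-ahead, which suits streaming use. A small sketch of the difference, with placeholder sample rate and FFT size:

import numpy as np
import librosa

sr, n_fft, hop = 16000, 1024, 256
y = np.random.randn(sr).astype(np.float32)  # 1 s of noise, purely illustrative

for center in (True, False):
    stft = librosa.stft(y, n_fft=n_fft, hop_length=hop, win_length=n_fft,
                        window='hann', center=center)
    times = librosa.frames_to_time(np.arange(stft.shape[1]), sr=sr, hop_length=hop)
    # center=True pads by n_fft//2, yielding more frames, each centred at k*hop/sr;
    # center=False uses only samples that already exist, so frame k starts at k*hop/sr.
    print(f"center={center}: {stft.shape[1]} frames, first timestamps {times[:3]}")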
@@ -781,80 +786,67 @@ class AudioProcessor:
         return dummy_spec, dummy_time

     def detect_onset_offset_advanced(self, vad_results: List[VADResult],
-                                     model_thresholds: Dict[str, float]) -> List[OnsetOffset]:
+                                     threshold: float,
+                                     apply_delay: float = 0.0,
+                                     min_duration: float = 0.12) -> List[OnsetOffset]:
         """
-        CHANGE 4: Exact threshold crossings, no smoothing or hysteresis.
+        Exact global-threshold crossings, with delay compensation and a minimum-duration filter.
         Onset: p[i-1] < thr and p[i] >= thr
         Offset: p[i-1] >= thr and p[i] < thr
         The crossing instant is obtained by linear interpolation between (t[i-1], p[i-1]) and (t[i], p[i]).
         """
-        onsets_offsets: List[OnsetOffset] = []
+        onsets_offsets = []
         if len(vad_results) < 2:
             return onsets_offsets

-        # group by model (base_name)
-        grouped: Dict[str, List[VADResult]] = {}
+        # group by model
+        grouped = {}
         for r in vad_results:
             base = r.model_name.split('(')[0].strip()
-            grouped.setdefault(base, []).append(r)
+            # apply the delay compensation when storing
+            grouped.setdefault(base, []).append(
+                VADResult(r.probability, r.is_speech, base, r.processing_time, r.timestamp - apply_delay)
+            )

         for base, rs in grouped.items():
             rs.sort(key=lambda r: r.timestamp)
             t = np.array([r.timestamp for r in rs], dtype=float)
             p = np.array([r.probability for r in rs], dtype=float)
-            thr = float(model_thresholds.get(base, 0.5))
+            thr = float(threshold)

             in_seg = False
             onset_t = None

-            # if we start above the threshold
             if p[0] > thr:
                 in_seg = True
                 onset_t = t[0]

+            def xcross(t0, p0, t1, p1, thr):
+                if p1 == p0: return t1
+                alpha = (thr - p0) / (p1 - p0)
+                return t0 + alpha * (t1 - t0)
+
             for i in range(1, len(p)):
                 p0, p1 = p[i-1], p[i]
                 t0, t1 = t[i-1], t[i]
-
-                # ONSET: p0 < thr and p1 >= thr
                 if (not in_seg) and (p0 < thr) and (p1 >= thr):
-                    if p1 == p0:
-                        cross = t1
-                    else:
-                        alpha = (thr - p0) / (p1 - p0)
-                        cross = t0 + alpha * (t1 - t0)
-                    onset_t = cross
+                    onset_t = xcross(t0, p0, t1, p1, thr)
                     in_seg = True
-
-                # OFFSET: p0 >= thr and p1 < thr
                 elif in_seg and (p0 >= thr) and (p1 < thr):
-                    if p1 == p0:
-                        cross = t1
-                    else:
-                        alpha = (thr - p0) / (p1 - p0)
-                        cross = t0 + alpha * (t1 - t0)
-                    # confidence as the mean of the raw probabilities inside the segment
-                    mask = (t >= onset_t) & (t <= cross)
-                    conf = float(p[mask].mean()) if np.any(mask) else float(max(p0, p1))
-                    onsets_offsets.append(OnsetOffset(
-                        onset_time=max(0.0, float(onset_t)),
-                        offset_time=float(cross),
-                        model_name=base,
-                        confidence=conf
-                    ))
+                    off = xcross(t0, p0, t1, p1, thr)
+                    if off - onset_t >= min_duration:  # debounce
+                        mask = (t >= onset_t) & (t <= off)
+                        conf = float(p[mask].mean()) if np.any(mask) else float(max(p0, p1))
+                        onsets_offsets.append(OnsetOffset(max(0.0, float(onset_t)), float(off), base, conf))
                     in_seg = False
                     onset_t = None

-            # if we end above the threshold, close at the last timestamp
             if in_seg and onset_t is not None:
-                mask = (t >= onset_t)
-                conf = float(p[mask].mean()) if np.any(mask) else float(p[-1])
-                onsets_offsets.append(OnsetOffset(
-                    onset_time=max(0.0, float(onset_t)),
-                    offset_time=float(t[-1]),
-                    model_name=base,
-                    confidence=conf
-                ))
+                off = float(t[-1])
+                if off - onset_t >= min_duration:
+                    mask = (t >= onset_t)
+                    conf = float(p[mask].mean()) if np.any(mask) else float(p[-1])
+                    onsets_offsets.append(OnsetOffset(max(0.0, float(onset_t)), off, base, conf))

         return onsets_offsets
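The crossing time in the rewritten function comes from straight-line interpolation between the two frames that straddle the threshold. A worked example with made-up numbers:

t0, t1 = 1.00, 1.05   # consecutive frame timestamps (s)
p0, p1 = 0.38, 0.62   # speech probabilities at those frames
thr = 0.50            # global threshold

alpha = (thr - p0) / (p1 - p0)   # fraction of the interval at which p reaches thr
onset = t0 + alpha * (t1 - t0)
print(alpha, onset)              # 0.5 1.025 -> onset reported at 1.025 s, not 1.05 s

With min_duration=0.12, any onset/offset pair closer together than 120 ms is discarded, which debounces spurious blips around the threshold.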
 
@@ -949,9 +941,9 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
         row=2, col=1
     )

-    # Use model-specific thresholds
-    thr_a = processor.model_thresholds.get(model_a, threshold)
-    thr_b = processor.model_thresholds.get(model_b, threshold)
+    # Use global threshold for both models
+    thr_a = threshold
+    thr_b = threshold

     if len(time_frames) > 0:
         # Add threshold lines using model-specific thresholds
@@ -972,10 +964,10 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
             yref="y4"  # Reference to secondary y-axis of second subplot
         )

-        # Add threshold annotations with model-specific values
+        # Add threshold annotations with global threshold
         fig.add_annotation(
             x=time_frames[-1] * 0.95, y=thr_a,
-            text=f'Threshold: {thr_a:.2f}',
+            text=f'Threshold: {threshold:.2f}',
             showarrow=False,
             font=dict(color='cyan', size=10),
             row=1, col=1,
@@ -983,7 +975,7 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
         )
         fig.add_annotation(
             x=time_frames[-1] * 0.95, y=thr_b,
-            text=f'Threshold: {thr_b:.2f}',
+            text=f'Threshold: {threshold:.2f}',
             showarrow=False,
             font=dict(color='cyan', size=10),
             row=2, col=1,
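The three plotting hunks above draw one horizontal threshold line plus a text label per subplot, now using the single global threshold. A stripped-down Plotly sketch of that pattern (toy traces; the app's secondary y-axis via yref="y4" is omitted here):

import plotly.graph_objects as go
from plotly.subplots import make_subplots

threshold = 0.5
times = [0.0, 1.0, 2.0]
fig = make_subplots(rows=2, cols=1, shared_xaxes=True)
fig.add_trace(go.Scatter(x=times, y=[0.2, 0.7, 0.4], name="Model A"), row=1, col=1)
fig.add_trace(go.Scatter(x=times, y=[0.1, 0.6, 0.8], name="Model B"), row=2, col=1)

for row in (1, 2):
    # dashed threshold line and a matching label in each subplot
    fig.add_hline(y=threshold, line_dash="dash", line_color="cyan", row=row, col=1)
    fig.add_annotation(x=times[-1] * 0.95, y=threshold, text=f"Threshold: {threshold:.2f}",
                       showarrow=False, font=dict(color="cyan", size=10), row=row, col=1)
fig.show()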
@@ -1191,7 +1183,7 @@ class VADDemo:
             if model_name in self.models:
                 window_size = self.processor.model_windows[model_name]
                 hop_size = self.processor.model_hop_sizes[model_name]
-                model_threshold = self.processor.model_thresholds.get(model_name, threshold)
+                model_threshold = threshold  # CORRECTED: Use global threshold from slider

                 window_samples = int(self.processor.sample_rate * window_size)
                 hop_samples = int(self.processor.sample_rate * hop_size)
@@ -1251,8 +1243,10 @@ class VADDemo:

         delay_compensation = self.processor.estimate_delay_compensation(processed_audio, vad_results)

-        # CHANGE 4: Use exact threshold crossing detection
-        onsets_offsets = self.processor.detect_onset_offset_advanced(vad_results, self.processor.model_thresholds)
+        # CORRECTED: Use global threshold with delay compensation and min duration
+        onsets_offsets = self.processor.detect_onset_offset_advanced(
+            vad_results, threshold, apply_delay=delay_compensation, min_duration=0.12
+        )

         debug_info.append(f"\n🎭 **EVENTS**: {len(onsets_offsets)} onset/offset pairs detected")
 
@@ -1282,18 +1276,15 @@ class VADDemo:
                 if result.is_speech:
                     summary['speech_chunks'] += 1

-        # Show model-specific thresholds
-        thr_a = self.processor.model_thresholds.get(model_a, threshold)
-        thr_b = self.processor.model_thresholds.get(model_b, threshold)
-        details_lines = [f"**Analysis Results** (Thresholds → {model_a}:{thr_a:.2f} | {model_b}:{thr_b:.2f})"]
+        # Show global threshold in analysis results
+        details_lines = [f"**Analysis Results** (Global Threshold: {threshold:.2f})"]

         for model_name, summary in model_summaries.items():
             avg_prob = np.mean(summary['probs']) if summary['probs'] else 0
             speech_ratio = (summary['speech_chunks'] / summary['total_chunks']) if summary['total_chunks'] > 0 else 0
-            model_thresh = self.processor.model_thresholds.get(model_name, threshold)

             status_icon = "🟢" if speech_ratio > 0.5 else "🟡" if speech_ratio > 0.2 else "🔴"
-            details_lines.append(f"{status_icon} **{model_name}**: {avg_prob:.3f} avg prob, {speech_ratio*100:.1f}% speech (thresh: {model_thresh:.2f})")
+            details_lines.append(f"{status_icon} **{model_name}**: {avg_prob:.3f} avg prob, {speech_ratio*100:.1f}% speech")

         if onsets_offsets:
             details_lines.append(f"\n**Speech Events**: {len(onsets_offsets)} detected")
@@ -1383,7 +1374,7 @@ def create_interface():
                     maximum=1.0,
                     value=0.5,
                     step=0.01,
-                    label="Global Detection Threshold (Reference Only)"
+                    label="Detection Threshold (Global)"
                 )

                 process_btn = gr.Button("🎤 Analyze", variant="primary", size="lg")
@@ -1421,7 +1412,7 @@ def create_interface():
     ---
     **Models**: Silero-VAD, WebRTC-VAD, E-PANNs, PANNs, AST | **Research**: WASPAA 2025 | **Institution**: University of Surrey, CVSSP

-    **Note**: All models now provide high temporal resolution (20Hz) for accurate real-time speech detection.
+    **Note**: Optimized for real-time performance with global threshold control and exact temporal alignment.
     """)

     return interface
 