Gabriel Bibbó committed on
Commit d7e6fe4 · 1 Parent(s): d758548

fix: adjustments in app.py

Files changed (1)
  1. app.py +222 -291
app.py CHANGED
@@ -1,177 +1,3 @@
- from __future__ import annotations  # postpones evaluation of annotations
- import numpy as np  # makes np visible to the rest of the module
-
- def predict(self, audio: np.ndarray, timestamp: float = 0.0, full_audio: np.ndarray = None) -> VADResult:
-     start_time = time.time()
-
-     if self.model is None or len(audio) == 0:
-         # Enhanced fallback using spectral features
-         if len(audio) > 0:
-             energy = np.sum(audio ** 2)
-             if LIBROSA_AVAILABLE:
-                 spectral_features = librosa.feature.spectral_rolloff(y=audio, sr=self.sample_rate)
-                 spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio, sr=self.sample_rate))
-                 probability = min((energy * 100 + spectral_centroid / 1000) / 2, 1.0)
-             else:
-                 probability = min(energy * 50, 1.0)
-             is_speech = probability > 0.25
-         else:
-             probability = 0.0
-             is_speech = False
-         return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
-
-     try:
-         # Cache key based on timestamp rounded to cache window
-         cache_key = int(timestamp / self.cache_window)
-
-         # Check cache first
-         if cache_key in self.prediction_cache:
-             cached_result = self.prediction_cache[cache_key]
-             # Return cached result with updated timestamp
-             return VADResult(
-                 cached_result.probability,
-                 cached_result.is_speech,
-                 cached_result.model_name + " (cached)",
-                 time.time() - start_time,
-                 timestamp
-             )
-
-         if len(audio.shape) > 1:
-             audio = audio.mean(axis=1)
-
-         # Use longer context for AST - preferably 6.4 seconds (1024 frames)
-         window_duration = 6.4  # seconds
-         window_samples = int(window_duration * self.sample_rate)
-
-         # If full_audio is provided, use it for better context
-         if full_audio is not None and len(full_audio) > window_samples:
-             # Take window centered around current timestamp
-             center_pos = int(timestamp * self.sample_rate)
-             half_window = window_samples // 2
-
-             start_pos = max(0, center_pos - half_window)
-             end_pos = min(len(full_audio), start_pos + window_samples)
-
-             # Adjust if at the end of audio
-             if end_pos == len(full_audio) and end_pos - start_pos < window_samples:
-                 start_pos = max(0, end_pos - window_samples)
-
-             audio_for_ast = full_audio[start_pos:end_pos]
-         else:
-             # Extract window from provided audio based on timestamp
-             center_sample = int(timestamp * self.sample_rate)
-             half_window = window_samples // 2
-
-             start_idx = max(0, center_sample - half_window)
-             end_idx = min(len(audio), start_idx + window_samples)
-
-             # Adjust if at the end
-             if end_idx == len(audio) and end_idx - start_idx < window_samples:
-                 start_idx = max(0, end_idx - window_samples)
-
-             audio_for_ast = audio[start_idx:end_idx]
-
-         # For short audio, use intelligent strategy
-         min_samples = int(6.4 * self.sample_rate)  # 6.4 seconds
-         if len(audio_for_ast) < min_samples:
-             # Repeat the audio cyclically to maintain temporal patterns
-             num_repeats = int(np.ceil(min_samples / len(audio_for_ast)))
-             audio_repeated = np.tile(audio_for_ast, num_repeats)[:min_samples]
-
-             # Apply smooth transitions at repetition boundaries
-             fade_samples = int(0.01 * self.sample_rate)  # 10ms fade
-             for i in range(1, num_repeats):
-                 if i * len(audio_for_ast) < len(audio_repeated):
-                     start_idx = i * len(audio_for_ast) - fade_samples
-                     end_idx = i * len(audio_for_ast) + fade_samples
-                     if start_idx >= 0 and end_idx < len(audio_repeated):
-                         audio_repeated[start_idx:end_idx] *= np.linspace(1, 1, 2 * fade_samples)
-
-             audio_for_ast = audio_repeated
-
-         # Truncate if too long
-         max_samples = 8 * self.sample_rate
-         if len(audio_for_ast) > max_samples:
-             audio_for_ast = audio_for_ast[:max_samples]
-
-         # Feature extraction
-         inputs = self.feature_extractor(
-             audio_for_ast,
-             sampling_rate=self.sample_rate,
-             return_tensors="pt",
-             max_length=1024,
-             padding="max_length",
-             truncation=True
-         )
-
-         # Move inputs to correct device and dtype
-         inputs = {k: v.to(self.device) for k, v in inputs.items()}
-         if self.device.type == 'cuda' and hasattr(self.model, 'half'):
-             inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
-
-         with torch.no_grad():
-             outputs = self.model(**inputs)
-             logits = outputs.logits
-             probs = torch.sigmoid(logits)
-
-         # Find speech-related classes
-         label2id = self.model.config.label2id
-         speech_indices = []
-         speech_keywords = [
-             'speech', 'voice', 'talk', 'conversation', 'speaking',
-             'male speech', 'female speech', 'child speech',
-             'speech synthesizer', 'narration'
-         ]
-
-         for lbl, idx in label2id.items():
-             if any(word in lbl.lower() for word in speech_keywords):
-                 speech_indices.append(idx)
-
-         # Also identify background/noise classes
-         noise_keywords = ['silence', 'white noise', 'background']
-         noise_indices = []
-         for lbl, idx in label2id.items():
-             if any(word in lbl.lower() for word in noise_keywords):
-                 noise_indices.append(idx)
-
-         if speech_indices:
-             # Use max probability among speech classes
-             speech_probs = probs[0, speech_indices]
-             speech_prob = torch.max(speech_probs).item()
-
-             # Consider noise/silence probability
-             if noise_indices:
-                 noise_prob = torch.mean(probs[0, noise_indices]).item()
-                 speech_prob = speech_prob * (1 - noise_prob * 0.3)
-
-             # Adjust confidence for short audio
-             if len(audio) < self.sample_rate * 2:
-                 confidence_factor = len(audio) / (self.sample_rate * 2)
-                 speech_prob = speech_prob * (0.6 + 0.4 * confidence_factor)
-
-         # ── END OF THE CALCULATION INSIDE try ──────────────────────────
-         is_speech_ast = speech_prob > 0.25
-         return VADResult(
-             float(speech_prob),
-             is_speech_ast,
-             self.model_name,
-             time.time() - start_time,
-             timestamp
-         )
-
-     except Exception as e:
-         print(f"❌ AST ERROR: {e}")
-         import traceback
-         traceback.print_exc()
-         return VADResult(
-             0.0,
-             False,
-             f"{self.model_name} (error)",
-             time.time() - start_time,
-             timestamp
-         )
-
-
  import gradio as gr
  import numpy as np
  import torch
@@ -243,14 +69,22 @@ except ImportError:
      PLOTLY_AVAILABLE = False
      print("⚠️ Plotly not available")

- # PANNs imports
  try:
-     from panns_inference import AudioTagging, labels
      PANNS_AVAILABLE = True
-     print("✅ PANNs available")
  except ImportError:
-     PANNS_AVAILABLE = False
-     print("⚠️ PANNs not available, using fallback")

  # Transformers for AST
  try:
@@ -264,6 +98,25 @@ except ImportError:

  print("🚀 Creating Real-time VAD Demo...")

  # ===== DATA STRUCTURES =====

  @dataclass
@@ -403,10 +256,20 @@ class OptimizedWebRTCVAD:
          return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)

  class OptimizedEPANNs:
      def __init__(self):
          self.model_name = "E-PANNs"
          self.sample_rate = 32000
          print(f"✅ {self.model_name} initialized")

      def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
          start_time = time.time()
@@ -436,7 +299,7 @@ class OptimizedEPANNs:

          audio_window = audio[start_idx:end_idx]

-         # Convert audio to target sample rate for E-PANNs
          if LIBROSA_AVAILABLE:
              # Resample to E-PANNs sample rate
              audio_resampled = librosa.resample(audio_window.astype(float),
@@ -450,30 +313,54 @@ class OptimizedEPANNs:
                  num_repeats = int(np.ceil(min_samples / len(audio_resampled)))
                  audio_resampled = np.tile(audio_resampled, num_repeats)[:min_samples]

-             # Compute features
-             mel_spec = librosa.feature.melspectrogram(y=audio_resampled, sr=self.sample_rate, n_mels=64)
-             energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))
-
-             # Use actual non-repeated audio for some features
-             actual_audio_len = min(len(audio_resampled), int(len(audio_window) * self.sample_rate / 16000))
-             actual_audio = audio_resampled[:actual_audio_len]
-
-             spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=actual_audio, sr=self.sample_rate))
-             mfcc = librosa.feature.mfcc(y=actual_audio, sr=self.sample_rate, n_mfcc=13)
-             mfcc_var = np.var(mfcc, axis=1).mean()
-             zcr = np.mean(librosa.feature.zero_crossing_rate(actual_audio))
-
-             # Adjusted scaling for better speech detection
-             energy_score = np.clip((energy + 80) / 40, 0, 1)
-             centroid_score = np.clip((spectral_centroid - 200) / 3000, 0, 1)
-             mfcc_score = np.clip(mfcc_var / 100, 0, 1)
-             zcr_score = np.clip(zcr * 10, 0, 1)
-
-             # Weighted combination
-             speech_score = (energy_score * 0.4 +
-                             centroid_score * 0.2 +
-                             mfcc_score * 0.3 +
-                             zcr_score * 0.1)
          else:
              from scipy import signal
              # Basic fallback without librosa
@@ -493,29 +380,44 @@ class OptimizedEPANNs:
          return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)

  class OptimizedPANNs:
      def __init__(self):
          self.model_name = "PANNs"
          self.sample_rate = 32000
          self.model = None
          self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
          self.load_model()

      def load_model(self):
          try:
              if PANNS_AVAILABLE:
-                 self.model = AudioTagging(checkpoint_path=None, device=self.device)
-                 print(f"✅ {self.model_name} loaded successfully")
              else:
                  print(f"⚠️ {self.model_name} not available, using fallback")
                  self.model = None
          except Exception as e:
              print(f"❌ Error loading {self.model_name}: {e}")
              self.model = None

      def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
          start_time = time.time()

-         if self.model is None or len(audio) == 0:
              if len(audio) > 0:
                  energy = np.sum(audio ** 2)
                  threshold = 0.01
@@ -579,48 +481,86 @@ class OptimizedPANNs:

              audio_resampled = audio_repeated

-         # Run inference
-         clip_probs, _ = self.model.inference(audio_resampled[np.newaxis, :])

-         # Enhanced speech detection using multiple relevant labels
-         speech_keywords = [
-             'speech', 'voice', 'talk', 'conversation', 'speaking',
-             'male speech', 'female speech', 'child speech',
-             'narration', 'monologue'
-         ]
-
-         speech_indices = []
-         for i, lbl in enumerate(labels):
-             if any(word in lbl.lower() for word in speech_keywords):
-                 speech_indices.append(i)
-
-         # Also get silence/noise indices for contrast
-         noise_keywords = ['silence', 'white noise', 'pink noise']
-         noise_indices = []
-         for i, lbl in enumerate(labels):
-             if any(word in lbl.lower() for word in noise_keywords):
-                 noise_indices.append(i)
-
-         if speech_indices:
-             # Get speech probability
-             speech_probs = clip_probs[0, speech_indices]
-             speech_prob = np.max(speech_probs)  # Use max instead of mean for better detection

-             # Get noise probability for contrast
-             if noise_indices:
-                 noise_prob = np.mean(clip_probs[0, noise_indices])
-                 # Adjust speech probability based on noise
-                 speech_prob = speech_prob * (1 - noise_prob * 0.5)

-             # If using repeated audio, scale confidence based on original length
-             if len(audio_window) < 16000 * 2:  # Less than 2 seconds
-                 confidence_scale = len(audio_window) / (16000 * 2)
-                 speech_prob = speech_prob * (0.5 + 0.5 * confidence_scale)

-         else:
-             # Fallback if no speech indices found
-             top_indices = np.argsort(clip_probs[0])[-10:]
-             speech_prob = np.mean(clip_probs[0, top_indices])

          return VADResult(float(speech_prob), speech_prob > 0.4, self.model_name, time.time()-start_time, timestamp)
@@ -639,9 +579,10 @@ class OptimizedPANNs:
              return VADResult(probability, is_speech, f"{self.model_name} (error)", time.time() - start_time, timestamp)

  class OptimizedAST:
      def __init__(self):
          self.model_name = "AST"
-         self.sample_rate = 16000
          self.model = None
          self.feature_extractor = None
          self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -720,59 +661,49 @@ class OptimizedAST:
              audio = audio.mean(axis=1)
              print(f"🔄 AST: Converted to mono")

-         # Use longer context for AST - preferably 6.4 seconds (1024 frames)
-         if full_audio is not None and len(full_audio) >= 6.4 * self.sample_rate:
-             print(f"✅ AST: Using full audio context")
-             # Take 6.4-second window centered around current timestamp
-             center_pos = int(timestamp * self.sample_rate)
-             window_size = int(3.2 * self.sample_rate)  # 3.2 seconds each side
-
-             start_pos = max(0, center_pos - window_size)
-             end_pos = min(len(full_audio), center_pos + window_size)
-
-             # Ensure we have at least 6.4 seconds
-             if end_pos - start_pos < 6.4 * self.sample_rate:
-                 end_pos = min(len(full_audio), start_pos + int(6.4 * self.sample_rate))
-                 if end_pos - start_pos < 6.4 * self.sample_rate:
-                     start_pos = max(0, end_pos - int(6.4 * self.sample_rate))
-
-             audio_for_ast = full_audio[start_pos:end_pos]
-             print(f"🔄 AST: Extracted window [{start_pos}:{end_pos}], len={len(audio_for_ast)}")
-         else:
-             print(f"⚠️ AST: Using provided audio chunk")
-             audio_for_ast = audio
-
          # For short audio, use intelligent strategy
-         min_samples = int(6.4 * self.sample_rate)  # 6.4 seconds
          if len(audio_for_ast) < min_samples:
-             print(f"⚠️ AST: Audio too short ({len(audio_for_ast)} samples), using cyclic repetition")
-             # Repeat the audio cyclically to maintain temporal patterns
-             num_repeats = int(np.ceil(min_samples / len(audio_for_ast)))
-             audio_repeated = np.tile(audio_for_ast, num_repeats)[:min_samples]
-
-             # Apply smooth transitions at repetition boundaries
-             fade_samples = int(0.01 * self.sample_rate)  # 10ms fade
-             for i in range(1, num_repeats):
-                 if i * len(audio_for_ast) < len(audio_repeated):
-                     start_idx = i * len(audio_for_ast) - fade_samples
-                     end_idx = i * len(audio_for_ast) + fade_samples
-                     if start_idx >= 0 and end_idx < len(audio_repeated):
-                         audio_repeated[start_idx:end_idx] *= np.linspace(1, 1, 2 * fade_samples)
-
-             audio_for_ast = audio_repeated
-             print(f"✅ AST: Repeated with smoothing, final_len={len(audio_for_ast)}")
-
-         # Truncate if too long (AST can handle up to ~10s, but we'll use 8s max for efficiency)
-         max_samples = 8 * self.sample_rate
          if len(audio_for_ast) > max_samples:
              audio_for_ast = audio_for_ast[:max_samples]
              print(f"✂️ AST: Truncated to {len(audio_for_ast)} samples")

          print(f"🔄 AST: Feature extraction...")
-         # Feature extraction with proper AST parameters (closer to 1024 frames)
          inputs = self.feature_extractor(
              audio_for_ast,
-             sampling_rate=self.sample_rate,
              return_tensors="pt",
              max_length=1024,  # Proper AST context
              padding="max_length",  # Ensure consistent length
@@ -896,7 +827,7 @@ class AudioProcessor:
              "WebRTC-VAD": 0.03,  # 30ms frames (480 samples)
              "E-PANNs": 6.0,  # 6 seconds minimum for reliable results
              "PANNs": 10.0,  # 10 seconds for optimal performance
-             "AST": 6.4  # ~6.4 seconds (1024 frames * 6.25ms)
          }

          # Model-specific hop sizes for efficiency

      PLOTLY_AVAILABLE = False
      print("⚠️ Plotly not available")

+ # PANNs imports - UPDATED to include SoundEventDetection
  try:
+     from panns_inference import AudioTagging, SoundEventDetection, labels
      PANNS_AVAILABLE = True
+     PANNS_SED_AVAILABLE = True
+     print("✅ PANNs available with SoundEventDetection")
  except ImportError:
+     try:
+         from panns_inference import AudioTagging, labels
+         PANNS_AVAILABLE = True
+         PANNS_SED_AVAILABLE = False
+         print("✅ PANNs available (AudioTagging only)")
+     except ImportError:
+         PANNS_AVAILABLE = False
+         PANNS_SED_AVAILABLE = False
+         print("⚠️ PANNs not available, using fallback")

  # Transformers for AST
  try:
 
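Side note on the two panns_inference entry points now imported: AudioTagging returns one clip-level score vector per input, while SoundEventDetection returns a score matrix per frame, which is what enables the framewise PANNs mode later in this diff. A minimal sketch of both call shapes, assuming panns_inference is installed and a 32 kHz mono waveform (the variable names and the 10 s duration are illustrative):

import numpy as np
from panns_inference import AudioTagging, SoundEventDetection, labels

wav_32k = np.zeros(32000 * 10, dtype=np.float32)               # placeholder: 10 s of silence at 32 kHz

at = AudioTagging(checkpoint_path=None, device='cpu')           # clip-level tagger
clipwise, embedding = at.inference(wav_32k[np.newaxis, :])      # clipwise: (1, 527) AudioSet scores

sed = SoundEventDetection(checkpoint_path=None, device='cpu')   # frame-level detector
framewise = sed.inference(wav_32k[np.newaxis, :])               # framewise: (1, n_frames, 527)

print(len(labels), clipwise.shape, framewise.shape)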

  print("🚀 Creating Real-time VAD Demo...")

+ # ===== HELPER FUNCTIONS FOR CORRECTED MODELS =====
+ def safe_resample(x, sr_in, sr_out):
+     """Safely resample audio from sr_in to sr_out"""
+     if sr_in == sr_out:
+         return x.astype(np.float32)
+     try:
+         if LIBROSA_AVAILABLE:
+             return librosa.resample(x.astype(float), orig_sr=sr_in, target_sr=sr_out)
+         else:
+             # Fallback linear interpolation
+             dur = len(x) / sr_in
+             n_out = max(1, int(round(dur * sr_out)))
+             xi = np.linspace(0, len(x)-1, num=len(x))
+             xo = np.linspace(0, len(x)-1, num=n_out)
+             return np.interp(xo, xi, x).astype(np.float32)
+     except Exception as e:
+         print(f"Resample error: {e}")
+         return x.astype(np.float32)
+
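A quick usage sketch for the new safe_resample helper; it relies on the LIBROSA_AVAILABLE flag defined earlier in app.py, and the 44.1 kHz chunk below is illustrative, not part of the commit:

import numpy as np

mic_chunk_44k = np.random.randn(44100).astype(np.float32)   # 1 s of dummy audio at 44.1 kHz
panns_chunk = safe_resample(mic_chunk_44k, 44100, 32000)     # about 32000 samples for PANNs / E-PANNs
ast_chunk = safe_resample(mic_chunk_44k, 44100, 16000)       # about 16000 samples for AST
print(panns_chunk.shape, ast_chunk.shape)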
  # ===== DATA STRUCTURES =====

  @dataclass
 
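The @dataclass line above introduces the VADResult container used by every predict() in this diff. Its definition is outside the changed hunks; judging from the call sites (probability, is_speech, model name, processing time, timestamp), a compatible sketch looks like this, with the field names being assumptions:

from dataclasses import dataclass

@dataclass
class VADResult:
    probability: float      # speech probability in [0, 1]
    is_speech: bool         # thresholded decision
    model_name: str         # e.g. "PANNs", "AST (cached)"
    processing_time: float  # seconds spent inside predict()
    timestamp: float        # position of the analysed window, in seconds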
          return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)

  class OptimizedEPANNs:
+     """CORRECTED E-PANNs with proper temporal resolution using sliding windows"""
      def __init__(self):
          self.model_name = "E-PANNs"
          self.sample_rate = 32000
          print(f"✅ {self.model_name} initialized")
+
+         # Try to load PANNs AudioTagging as backend for E-PANNs
+         self.at_model = None
+         if PANNS_AVAILABLE:
+             try:
+                 self.at_model = AudioTagging(checkpoint_path=None, device='cpu')
+                 print(f"✅ {self.model_name} using PANNs AT backend")
+             except Exception as e:
+                 print(f"⚠️ {self.model_name} PANNs AT unavailable: {e}")

      def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
          start_time = time.time()
 
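Both the E-PANNs and PANNs paths below reduce the 527 AudioSet labels to speech-related indices with the same keyword filter. Here is that lookup in isolation (it could equally be precomputed once at load time); it assumes labels from panns_inference:

from panns_inference import labels  # 527 AudioSet class names

speech_keywords = ['speech', 'voice', 'talk', 'conversation', 'speaking',
                   'male speech', 'female speech', 'child speech',
                   'narration', 'monologue']
speech_indices = [i for i, lbl in enumerate(labels)
                  if any(word in lbl.lower() for word in speech_keywords)]
print(f"{len(speech_indices)} speech-like classes out of {len(labels)}")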

          audio_window = audio[start_idx:end_idx]

+         # Convert audio to target sample rate for E-PANNs (32kHz)
          if LIBROSA_AVAILABLE:
              # Resample to E-PANNs sample rate
              audio_resampled = librosa.resample(audio_window.astype(float),
 
                  num_repeats = int(np.ceil(min_samples / len(audio_resampled)))
                  audio_resampled = np.tile(audio_resampled, num_repeats)[:min_samples]

+             # If we have PANNs AT model, use it
+             if self.at_model is not None:
+                 # Run inference
+                 clipwise_output, _ = self.at_model.inference(audio_resampled[np.newaxis, :])
+
+                 # Get speech-related classes
+                 speech_keywords = [
+                     'speech', 'voice', 'talk', 'conversation', 'speaking',
+                     'male speech', 'female speech', 'child speech',
+                     'narration', 'monologue'
+                 ]
+
+                 speech_indices = []
+                 for i, lbl in enumerate(labels):
+                     if any(word in lbl.lower() for word in speech_keywords):
+                         speech_indices.append(i)
+
+                 if speech_indices:
+                     speech_probs = clipwise_output[0, speech_indices]
+                     speech_score = float(np.max(speech_probs))
+                 else:
+                     speech_score = float(np.max(clipwise_output[0]))
+             else:
+                 # Fallback to spectral features
+                 # Compute features
+                 mel_spec = librosa.feature.melspectrogram(y=audio_resampled, sr=self.sample_rate, n_mels=64)
+                 energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))
+
+                 # Use actual non-repeated audio for some features
+                 actual_audio_len = min(len(audio_resampled), int(len(audio_window) * self.sample_rate / 16000))
+                 actual_audio = audio_resampled[:actual_audio_len]
+
+                 spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=actual_audio, sr=self.sample_rate))
+                 mfcc = librosa.feature.mfcc(y=actual_audio, sr=self.sample_rate, n_mfcc=13)
+                 mfcc_var = np.var(mfcc, axis=1).mean()
+                 zcr = np.mean(librosa.feature.zero_crossing_rate(actual_audio))
+
+                 # Adjusted scaling for better speech detection
+                 energy_score = np.clip((energy + 80) / 40, 0, 1)
+                 centroid_score = np.clip((spectral_centroid - 200) / 3000, 0, 1)
+                 mfcc_score = np.clip(mfcc_var / 100, 0, 1)
+                 zcr_score = np.clip(zcr * 10, 0, 1)
+
+                 # Weighted combination
+                 speech_score = (energy_score * 0.4 +
+                                 centroid_score * 0.2 +
+                                 mfcc_score * 0.3 +
+                                 zcr_score * 0.1)
          else:
              from scipy import signal
              # Basic fallback without librosa
 
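To make the spectral fallback above concrete, this is the weighted combination with illustrative per-feature scores (the 0.4/0.2/0.3/0.1 weights are the ones in the diff; the input values are made up):

energy_score, centroid_score, mfcc_score, zcr_score = 0.8, 0.5, 0.6, 0.3  # illustrative values in [0, 1]
speech_score = (energy_score * 0.4 +
                centroid_score * 0.2 +
                mfcc_score * 0.3 +
                zcr_score * 0.1)
print(round(speech_score, 2))  # 0.32 + 0.10 + 0.18 + 0.03 = 0.63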
          return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)

  class OptimizedPANNs:
+     """CORRECTED PANNs with SoundEventDetection for framewise output when available"""
      def __init__(self):
          self.model_name = "PANNs"
          self.sample_rate = 32000
          self.model = None
+         self.sed_model = None
          self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
          self.load_model()

      def load_model(self):
          try:
              if PANNS_AVAILABLE:
+                 # Try to load SED model first for framewise output
+                 if PANNS_SED_AVAILABLE:
+                     try:
+                         self.sed_model = SoundEventDetection(checkpoint_path=None, device=self.device)
+                         print(f"✅ {self.model_name} SED loaded successfully (framewise mode)")
+                     except Exception as e:
+                         print(f"⚠️ {self.model_name} SED initialization failed: {e}")
+                         self.sed_model = None
+
+                 # Load AudioTagging as fallback or primary
+                 if self.sed_model is None:
+                     self.model = AudioTagging(checkpoint_path=None, device=self.device)
+                     print(f"✅ {self.model_name} AT loaded successfully")
              else:
                  print(f"⚠️ {self.model_name} not available, using fallback")
                  self.model = None
+                 self.sed_model = None
          except Exception as e:
              print(f"❌ Error loading {self.model_name}: {e}")
              self.model = None
+             self.sed_model = None

      def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
          start_time = time.time()

+         if (self.model is None and self.sed_model is None) or len(audio) == 0:
              if len(audio) > 0:
                  energy = np.sum(audio ** 2)
                  threshold = 0.01
 

          audio_resampled = audio_repeated

+         # Use SED for framewise predictions if available
+         if self.sed_model is not None:
+             # SED gives framewise output
+             framewise_output = self.sed_model.inference(audio_resampled[np.newaxis, :])
+
+             if hasattr(framewise_output, 'cpu'):
+                 framewise_output = framewise_output.cpu().numpy()
+
+             if framewise_output.ndim == 3:
+                 framewise_output = framewise_output[0]  # Remove batch dimension
+
+             # Get frame corresponding to timestamp
+             audio_duration = len(audio_resampled) / self.sample_rate
+             if audio_duration > 0:
+                 frame_idx = int((timestamp % audio_duration) / audio_duration * framewise_output.shape[0])
+                 frame_idx = min(frame_idx, framewise_output.shape[0] - 1)
+             else:
+                 frame_idx = 0
+
+             # Get speech-related classes
+             speech_keywords = [
+                 'speech', 'voice', 'talk', 'conversation', 'speaking',
+                 'male speech', 'female speech', 'child speech',
+                 'narration', 'monologue'
+             ]
+
+             speech_indices = []
+             for i, lbl in enumerate(labels):
+                 if any(word in lbl.lower() for word in speech_keywords):
+                     speech_indices.append(i)
+
+             if speech_indices and frame_idx < framewise_output.shape[0]:
+                 speech_probs = framewise_output[frame_idx, speech_indices]
+                 speech_prob = float(np.max(speech_probs))
+             else:
+                 speech_prob = float(np.max(framewise_output[frame_idx])) if frame_idx < framewise_output.shape[0] else 0.0
+         else:
+             # Use AudioTagging model
+             # Run inference
+             clip_probs, _ = self.model.inference(audio_resampled[np.newaxis, :])

+             # Enhanced speech detection using multiple relevant labels
+             speech_keywords = [
+                 'speech', 'voice', 'talk', 'conversation', 'speaking',
+                 'male speech', 'female speech', 'child speech',
+                 'narration', 'monologue'
+             ]

+             speech_indices = []
+             for i, lbl in enumerate(labels):
+                 if any(word in lbl.lower() for word in speech_keywords):
+                     speech_indices.append(i)
+
+             # Also get silence/noise indices for contrast
+             noise_keywords = ['silence', 'white noise', 'pink noise']
+             noise_indices = []
+             for i, lbl in enumerate(labels):
+                 if any(word in lbl.lower() for word in noise_keywords):
+                     noise_indices.append(i)

+             if speech_indices:
+                 # Get speech probability
+                 speech_probs = clip_probs[0, speech_indices]
+                 speech_prob = np.max(speech_probs)  # Use max instead of mean for better detection

+                 # Get noise probability for contrast
+                 if noise_indices:
+                     noise_prob = np.mean(clip_probs[0, noise_indices])
+                     # Adjust speech probability based on noise
+                     speech_prob = speech_prob * (1 - noise_prob * 0.5)
+
+                 # If using repeated audio, scale confidence based on original length
+                 if len(audio_window) < 16000 * 2:  # Less than 2 seconds
+                     confidence_scale = len(audio_window) / (16000 * 2)
+                     speech_prob = speech_prob * (0.5 + 0.5 * confidence_scale)
+
+             else:
+                 # Fallback if no speech indices found
+                 top_indices = np.argsort(clip_probs[0])[-10:]
+                 speech_prob = np.mean(clip_probs[0, top_indices])

          return VADResult(float(speech_prob), speech_prob > 0.4, self.model_name, time.time()-start_time, timestamp)
 
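The frame-index arithmetic in the SED branch above maps a timestamp onto the framewise output grid. A worked example with illustrative numbers (the real frame count depends on the checkpoint's hop size):

import numpy as np

sample_rate = 32000
audio_resampled = np.zeros(sample_rate * 10)          # pretend 10 s analysis window
n_frames = 1000                                       # illustrative number of SED frames

timestamp = 3.2                                       # seconds
audio_duration = len(audio_resampled) / sample_rate   # 10.0 s
frame_idx = int((timestamp % audio_duration) / audio_duration * n_frames)
frame_idx = min(frame_idx, n_frames - 1)
print(frame_idx)                                      # 320 -> the frame 3.2 s into the window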
 
              return VADResult(probability, is_speech, f"{self.model_name} (error)", time.time() - start_time, timestamp)

  class OptimizedAST:
+     """CORRECTED AST with proper 16kHz sample rate and sliding windows"""
      def __init__(self):
          self.model_name = "AST"
+         self.sample_rate = 16000  # AST REQUIRES 16kHz
          self.model = None
          self.feature_extractor = None
          self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
              audio = audio.mean(axis=1)
              print(f"🔄 AST: Converted to mono")

+         # CRITICAL FIX: AST uses 16kHz, but input is already at 16kHz
+         # So we DON'T need to resample, just ensure it's float32
+         audio = audio.astype(np.float32)
+
+         # Use sliding window approach for temporal resolution
+         window_duration = 1.0  # 1 second windows
+         window_samples = int(window_duration * self.sample_rate)
+
+         # Get window for this timestamp
+         center_sample = int(timestamp * self.sample_rate)
+         half_window = window_samples // 2
+
+         start_idx = max(0, center_sample - half_window)
+         end_idx = min(len(audio), start_idx + window_samples)
+
+         # Adjust if at the end
+         if end_idx == len(audio) and end_idx - start_idx < window_samples:
+             start_idx = max(0, end_idx - window_samples)
+
+         audio_for_ast = audio[start_idx:end_idx]
+         print(f"🔄 AST: Extracted window [{start_idx}:{end_idx}], len={len(audio_for_ast)}")
+
          # For short audio, use intelligent strategy
+         min_samples = int(1.0 * self.sample_rate)  # 1 second minimum
          if len(audio_for_ast) < min_samples:
+             print(f"⚠️ AST: Audio too short ({len(audio_for_ast)} samples), padding")
+             # Pad with zeros
+             audio_padded = np.zeros(min_samples)
+             audio_padded[:len(audio_for_ast)] = audio_for_ast
+             audio_for_ast = audio_padded
+             print(f"✅ AST: Padded to {len(audio_for_ast)} samples")
+
+         # Truncate if too long (AST can handle up to ~10s, but we use 1s windows)
+         max_samples = int(1.5 * self.sample_rate)
          if len(audio_for_ast) > max_samples:
              audio_for_ast = audio_for_ast[:max_samples]
              print(f"✂️ AST: Truncated to {len(audio_for_ast)} samples")

          print(f"🔄 AST: Feature extraction...")
+         # Feature extraction with proper AST parameters
          inputs = self.feature_extractor(
              audio_for_ast,
+             sampling_rate=self.sample_rate,  # Must be 16kHz
              return_tensors="pt",
              max_length=1024,  # Proper AST context
              padding="max_length",  # Ensure consistent length
 
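For reference, the feature-extractor call above always yields a fixed-size spectrogram, because the AST extractor pads or truncates to 1024 mel frames; a 1 s window therefore arrives mostly zero-padded. A minimal sketch with the stock AudioSet AST checkpoint (the model id below is an assumption, not taken from this diff):

import numpy as np
from transformers import ASTFeatureExtractor

fe = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
one_second = np.zeros(16000, dtype=np.float32)        # 1 s of silence at 16 kHz

inputs = fe(one_second, sampling_rate=16000, return_tensors="pt")
print(inputs["input_values"].shape)                   # (1, 1024, 128): 1024 frames x 128 mel bins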
              "WebRTC-VAD": 0.03,  # 30ms frames (480 samples)
              "E-PANNs": 6.0,  # 6 seconds minimum for reliable results
              "PANNs": 10.0,  # 10 seconds for optimal performance
+             "AST": 1.0  # Changed to 1 second for better temporal resolution
          }

          # Model-specific hop sizes for efficiency
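The window table above decides how much context each model gets per call; dropping AST from 6.4 s to 1.0 s is what buys it finer temporal resolution. A hypothetical driver loop (not the app's actual AudioProcessor) showing how such a table is typically consumed, assuming 16 kHz input and a 0.5 s hop:

import numpy as np

min_window_durations = {"WebRTC-VAD": 0.03, "E-PANNs": 6.0, "PANNs": 10.0, "AST": 1.0}
sample_rate = 16000
hop = 0.5                                    # seconds between evaluations (illustrative)
audio = np.zeros(sample_rate * 30)           # 30 s of dummy audio

for model_name, min_dur in min_window_durations.items():
    usable = max(0.0, len(audio) / sample_rate - min_dur)
    n_evals = int(usable / hop) + 1
    print(f"{model_name}: {n_evals} evaluations of {min_dur}s windows at a {hop}s hop")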