Gabriel Bibbó committed
Commit 25b51aa · 1 Parent(s): 2a8cb45

Fix threshold lines visibility and AST probability detection

Files changed (1):
  1. app.py +60 -89

app.py CHANGED
@@ -362,9 +362,6 @@ class OptimizedAST:
         self.model = None
         self.feature_extractor = None
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        # Cache for long audio segments (not tiny chunks)
-        self.segment_cache = {}
-        self.min_audio_length = self.sample_rate  # 1 second minimum
         self.load_model()
 
     def load_model(self):
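The attributes deleted here belong to the segment cache that this commit removes wholesale further down. As a hedged aside: the removed cache key (`cache_key = f"{int(timestamp * 10)}"`, deleted below) bucketed results purely by 100 ms of timestamp, so any two chunks falling in the same bucket shared an entry regardless of their audio content. A toy illustration (values invented):

# Two different probe times, same 100 ms bucket -> same cache key.
key_a = f"{int(3.27 * 10)}"  # chunk at t = 3.27 s -> "32"
key_b = f"{int(3.21 * 10)}"  # different chunk at t = 3.21 s -> "32"
assert key_a == key_b        # stale probability served for the second chunk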
@@ -374,8 +371,6 @@ class OptimizedAST:
             self.feature_extractor = ASTFeatureExtractor.from_pretrained(model_name)
             self.model = ASTForAudioClassification.from_pretrained(model_name)
             self.model.to(self.device)
-            if torch.cuda.is_available():
-                self.model.half()  # Use FP16 for speed
             self.model.eval()
             print(f"✅ {self.model_name} loaded successfully")
         else:
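Dropping `self.model.half()` sidesteps the FP16 pitfall this path had: the deleted inference code below must re-cast every float32 tensor coming out of the feature extractor to half precision, or PyTorch raises a dtype mismatch. A minimal sketch of an alternative that keeps weights in float32 and lets autocast handle precision on CUDA (`infer_fp16_safe` is my name, not part of this commit):

import torch

def infer_fp16_safe(model, inputs: dict, device: torch.device) -> torch.Tensor:
    # Weights stay float32; autocast downcasts eligible ops on CUDA only,
    # so CPU runs are unaffected and no manual .half() casting is needed.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad(), torch.autocast(device_type=device.type,
                                         enabled=(device.type == "cuda")):
        return model(**inputs).logits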
@@ -396,10 +391,10 @@ class OptimizedAST:
                 spectral_features = librosa.feature.spectral_rolloff(y=audio, sr=self.sample_rate)
                 spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio, sr=self.sample_rate))
                 # Combine multiple features for better speech detection
-                probability = min((energy * 50 + spectral_centroid / 1000 + np.mean(spectral_features) / 1000) / 3, 1.0)
+                probability = min((energy * 100 + spectral_centroid / 500) / 2, 1.0)
             else:
-                probability = min(energy / 0.01, 1.0)
-                is_speech = probability > 0.3  # Lower threshold for fallback
+                probability = min(energy * 50, 1.0)
+            is_speech = probability > 0.3
         else:
             probability = 0.0
             is_speech = False
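For a sense of scale in the rewritten fallback formula (the numbers below are illustrative, not from the commit): with a quiet chunk and a typical speech centroid near 1.8 kHz, the centroid term alone exceeds the cap.

energy = 0.004              # np.sum(audio ** 2) for a quiet chunk (assumed)
spectral_centroid = 1800.0  # Hz, a common ballpark for speech (assumed)

probability = min((energy * 100 + spectral_centroid / 500) / 2, 1.0)
# (0.4 + 3.6) / 2 = 2.0 -> capped at 1.0; for centroids above ~1 kHz the
# centroid term saturates the formula, so energy only matters below that.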
@@ -409,98 +404,71 @@ class OptimizedAST:
             if len(audio.shape) > 1:
                 audio = audio.mean(axis=1)
 
-            # AST needs much longer context - use at least 1 second
-            if len(audio) < self.min_audio_length:
-                # Pad to minimum length (1 second)
-                audio = np.pad(audio, (0, self.min_audio_length - len(audio)), 'constant')
-
-            # For very long audio, take a representative 2-second segment
-            if len(audio) > self.sample_rate * 2:
-                # Take segment around current timestamp from full audio if available
-                if full_audio is not None and len(full_audio) > self.sample_rate:
-                    # Calculate position in full audio
-                    center_pos = int(timestamp * self.sample_rate)
-                    half_window = self.sample_rate  # 1 second each side
-
-                    start_pos = max(0, center_pos - half_window)
-                    end_pos = min(len(full_audio), center_pos + half_window)
-
-                    # Ensure we have at least 1 second
-                    if end_pos - start_pos < self.min_audio_length:
-                        end_pos = min(len(full_audio), start_pos + self.min_audio_length)
-
-                    audio = full_audio[start_pos:end_pos]
-                else:
-                    # Fallback: take middle part
-                    start_idx = (len(audio) - self.sample_rate * 2) // 2
-                    audio = audio[start_idx:start_idx + self.sample_rate * 2]
-
-            # Create cache key based on timestamp range instead of audio bytes
-            cache_key = f"{int(timestamp * 10)}"  # Cache per 100ms of timestamp
-
-            if cache_key in self.segment_cache:
-                speech_prob = self.segment_cache[cache_key]
-            else:
-                # Feature extraction with proper parameters for AST
-                inputs = self.feature_extractor(
-                    audio,
-                    sampling_rate=self.sample_rate,
-                    return_tensors="pt",
-                    padding="max_length",
-                    max_length=1024,  # Proper context length (~10s worth of frames)
-                    truncation=True
-                )
-
-                # Move to device and convert to proper dtype
-                inputs = {k: v.to(self.device) for k, v in inputs.items()}
-                if torch.cuda.is_available():
-                    inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
-
-                with torch.no_grad():
-                    outputs = self.model(**inputs)
-                    logits = outputs.logits
-                    probs = torch.sigmoid(logits)
-
-                # Find speech-related indices with broader search
-                label2id = self.model.config.label2id
-                speech_indices = []
-                speech_keywords = ['speech', 'voice', 'talk', 'conversation', 'speaking', 'human', 'vocal', 'verbal']
-
-                for lbl, idx in label2id.items():
-                    if any(word in lbl.lower() for word in speech_keywords):
-                        speech_indices.append(idx)
-
-                if speech_indices:
-                    speech_prob = probs[0, speech_indices].mean().item()
-                else:
-                    # Enhanced fallback: look for any human-related audio classes
-                    human_indices = []
-                    for lbl, idx in label2id.items():
-                        if any(word in lbl.lower() for word in ['human', 'people', 'person', 'male', 'female', 'child']):
-                            human_indices.append(idx)
-
-                    if human_indices:
-                        speech_prob = probs[0, human_indices].mean().item()
-                    else:
-                        # Last resort: use top activations
-                        speech_prob = probs[0].topk(10).values.mean().item()
-
-                # Cache with limited size
-                if len(self.segment_cache) < 100:
-                    self.segment_cache[cache_key] = speech_prob
-                elif len(self.segment_cache) >= 200:  # Clear cache when too large
-                    self.segment_cache.clear()
-
-            return VADResult(float(speech_prob), speech_prob > 0.5, self.model_name, time.time()-start_time, timestamp)
+            # Use longer context for AST - take from full audio if available
+            if full_audio is not None and len(full_audio) > self.sample_rate:
+                # Take 3-second window centered around current timestamp
+                center_pos = int(timestamp * self.sample_rate)
+                window_size = int(1.5 * self.sample_rate)  # 1.5 seconds each side
+
+                start_pos = max(0, center_pos - window_size)
+                end_pos = min(len(full_audio), center_pos + window_size)
+
+                # Ensure we have at least 1 second
+                if end_pos - start_pos < self.sample_rate:
+                    end_pos = min(len(full_audio), start_pos + self.sample_rate)
+
+                audio_for_ast = full_audio[start_pos:end_pos]
+            else:
+                audio_for_ast = audio
+
+            # Ensure minimum length for AST
+            if len(audio_for_ast) < self.sample_rate:
+                audio_for_ast = np.pad(audio_for_ast, (0, self.sample_rate - len(audio_for_ast)), 'constant')
+
+            # Feature extraction with proper AST parameters
+            inputs = self.feature_extractor(
+                audio_for_ast,
+                sampling_rate=self.sample_rate,
+                return_tensors="pt",
+                max_length=1024,  # Proper AST context
+                truncation=True
+            )
+
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                logits = outputs.logits
+                probs = torch.sigmoid(logits)
+
+            # Find speech-related classes
+            label2id = self.model.config.label2id
+            speech_indices = []
+            speech_keywords = ['speech', 'voice', 'talk', 'conversation', 'speaking']
+
+            for lbl, idx in label2id.items():
+                if any(word in lbl.lower() for word in speech_keywords):
+                    speech_indices.append(idx)
+
+            if speech_indices:
+                speech_prob = probs[0, speech_indices].mean().item()
+                # Boost the probability if it's too low but there's clear audio content
+                if speech_prob < 0.1 and np.sum(audio_for_ast ** 2) > 0.001:
+                    speech_prob = min(speech_prob * 5, 0.8)  # Boost but cap at 0.8
+            else:
+                # Fallback to energy-based detection
+                energy = np.sum(audio_for_ast ** 2)
+                speech_prob = min(energy * 20, 1.0)
+
+            return VADResult(float(speech_prob), speech_prob > 0.4, self.model_name, time.time()-start_time, timestamp)
 
         except Exception as e:
             print(f"Error in {self.model_name}: {e}")
             # Enhanced fallback
             if len(audio) > 0:
                 energy = np.sum(audio ** 2)
-                # Use energy-based detection with better threshold
-                probability = min(energy / 0.005, 1.0)  # More sensitive threshold
-                is_speech = energy > 0.005
+                probability = min(energy * 30, 1.0)  # More aggressive energy scaling
+                is_speech = energy > 0.002
             else:
                 probability = 0.0
                 is_speech = False
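The core of the hunk above is the centered window: roughly 3 s of context around the chunk's timestamp, never shorter than 1 s. A self-contained restatement of that logic (the function name and the bare `sr` parameter are mine; the steps mirror the added lines):

import numpy as np

def centered_window(full_audio: np.ndarray, timestamp: float, sr: int) -> np.ndarray:
    """Return ~3 s of audio centered on `timestamp`, at least 1 s long."""
    center = int(timestamp * sr)
    half = int(1.5 * sr)                       # 1.5 s each side
    start = max(0, center - half)
    end = min(len(full_audio), center + half)
    if end - start < sr:                       # widen to at least 1 s
        end = min(len(full_audio), start + sr)
    window = full_audio[start:end]
    if len(window) < sr:                       # zero-pad very short recordings
        window = np.pad(window, (0, sr - len(window)), 'constant')
    return window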
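The class lookup relies on AST's AudioSet head being multi-label, which is why the hunk applies `torch.sigmoid` per class rather than a softmax. A sketch of the keyword scan over `model.config.label2id` (keywords copied from the diff; the four-label mapping is a made-up stand-in for the real 527 AudioSet labels):

import torch

speech_keywords = ['speech', 'voice', 'talk', 'conversation', 'speaking']

label2id = {  # toy stand-in for model.config.label2id
    'Speech': 0, 'Narration, monologue': 1, 'Music': 2,
    'Male speech, man speaking': 3,
}

speech_indices = [idx for lbl, idx in label2id.items()
                  if any(word in lbl.lower() for word in speech_keywords)]
# -> [0, 3]; note 'Narration, monologue' slips through, one reason the
#    keyword match stays a heuristic backed by the energy fallback.

probs = torch.sigmoid(torch.randn(1, len(label2id)))  # fake logits
speech_prob = probs[0, speech_indices].mean().item()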
@@ -772,6 +740,7 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
     )
 
     if len(time_frames) > 0:
+        # Add threshold lines to both panels
        fig.add_hline(
             y=threshold,
             line=dict(color='cyan', width=2, dash='dash'),
@@ -782,6 +751,8 @@
         fig.add_hline(
             y=threshold,
             line=dict(color='cyan', width=2, dash='dash'),
+            annotation_text=f'Threshold: {threshold:.2f}',
+            annotation_position="top right",
             row=2, col=1, secondary_y=True
         )
 
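On the plotting side, the fix is to give the second panel's `add_hline` the same row/col/secondary-axis routing as its trace, plus a visible annotation. A minimal standalone sketch with a synthetic figure (not the app's actual layout):

import plotly.graph_objects as go
from plotly.subplots import make_subplots

threshold = 0.5
fig = make_subplots(rows=2, cols=1, specs=[[{}], [{'secondary_y': True}]])
fig.add_trace(go.Scatter(y=[0.2, 0.7, 0.4]), row=2, col=1, secondary_y=True)

# Dashed threshold line pinned to the second panel's secondary y-axis;
# without explicit row/col it can land on the wrong subplot and look invisible.
fig.add_hline(
    y=threshold,
    line=dict(color='cyan', width=2, dash='dash'),
    annotation_text=f'Threshold: {threshold:.2f}',
    annotation_position='top right',
    row=2, col=1, secondary_y=True,
)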
758