from typing import Dict, Tuple

import torch


# ============================================================================
# Loss Functions
# ============================================================================

def sdr_loss(estimated: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute negative SDR loss.

    Based on the definition from Vincent et al. 2006:
    SDR = 10 * log10(||target||^2 / ||target - estimated||^2).

    Args:
        estimated: Estimated signal; the first dimension is treated as batch.
        target: Target signal; the first dimension is treated as batch.

    Returns:
        Scalar tensor: the negated batch mean of the per-item SDR (in dB),
        where each per-item SDR is clamped to [-30, 30] before averaging.
    """
    # Flatten to [batch, -1] to ensure compatible shapes
    est_flat = estimated.reshape(estimated.shape[0], -1)
    tgt_flat = target.reshape(target.shape[0], -1)

    delta = 1e-8  # Small constant for numerical stability
    num = torch.sum(tgt_flat ** 2, dim=-1)
    den = torch.sum((tgt_flat - est_flat) ** 2, dim=-1)

    # delta keeps both the ratio and the log finite when either energy is zero
    sdr = 10 * torch.log10((num + delta) / (den + delta))

    # Clamp to reasonable range to avoid extreme values
    sdr = torch.clamp(sdr, min=-30, max=30)

    return -sdr.mean()  # Return negative for minimization


def sisdr_loss(estimated: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute negative SI-SDR (Scale-Invariant SDR) loss.

    This is more robust to scaling differences between estimate and target.

    Args:
        estimated: Estimated signal; the first dimension is treated as batch.
        target: Target signal; the first dimension is treated as batch.

    Returns:
        Scalar tensor: the negated batch mean of the per-item SI-SDR (in dB),
        where each per-item SI-SDR is clamped to [-30, 30] before averaging.
    """
    # Flatten to [batch, -1]
    est_flat = estimated.reshape(estimated.shape[0], -1)
    tgt_flat = target.reshape(target.shape[0], -1)

    # Zero-mean normalization (critical for SI-SDR)
    est_flat = est_flat - est_flat.mean(dim=-1, keepdim=True)
    tgt_flat = tgt_flat - tgt_flat.mean(dim=-1, keepdim=True)

    # SI-SDR calculation
    # Project estimate onto target: s_target = (s_hat . s) / ||s||^2 * s
    delta = 1e-8
    dot = torch.sum(est_flat * tgt_flat, dim=-1, keepdim=True)
    s_target_norm_sq = torch.sum(tgt_flat ** 2, dim=-1, keepdim=True)

    # Scaled target
    s_target = (dot / (s_target_norm_sq + delta)) * tgt_flat

    # Noise is the orthogonal component
    e_noise = est_flat - s_target

    # SI-SDR = 10 * log10(||s_target||^2 / ||e_noise||^2)
    s_target_energy = torch.sum(s_target ** 2, dim=-1)
    e_noise_energy = torch.sum(e_noise ** 2, dim=-1)
    sisdr = 10 * torch.log10((s_target_energy + delta) / (e_noise_energy + delta))

    # Clamp to reasonable range
    sisdr = torch.clamp(sisdr, min=-30, max=30)

    return -sisdr.mean()  # Return negative for minimization


def new_sdr_metric(estimated: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute the SDR according to the MDX challenge definition (positive values).

    This is for evaluation/logging, not for loss. Unlike ``sdr_loss`` the
    result is neither clamped nor averaged over the batch.

    Args:
        estimated: (batch, channels, time); any tensor with a leading batch
            dimension and at least one further dimension is accepted.
        target: Same layout as ``estimated``.

    Returns:
        SDR scores per batch item (batch,)
    """
    delta = 1e-8

    # Subtract before flattening so the two inputs interact exactly as they
    # would in their original layout, then reduce over every non-batch dim.
    residual = target - estimated
    num = torch.flatten(target ** 2, start_dim=1).sum(dim=-1)
    den = torch.flatten(residual ** 2, start_dim=1).sum(dim=-1)

    scores = 10 * torch.log10((num + delta) / (den + delta))
    return scores


def combined_loss(
    estimated: torch.Tensor,
    target: torch.Tensor,
    sdr_weight: float = 0.9,
    sisdr_weight: float = 0.1
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Combined SDR and SI-SDR loss.

    Args:
        estimated: Estimated audio (batch, channels, time)
        target: Target audio (batch, channels, time)
        sdr_weight: Weight for SDR loss (default 0.9)
        sisdr_weight: Weight for SI-SDR loss (default 0.1)

    Returns:
        total_loss: Combined loss for backpropagation
        metrics: Dictionary of metrics for logging
    """
    sdr = sdr_loss(estimated, target)
    sisdr = sisdr_loss(estimated, target)

    total = sdr_weight * sdr + sisdr_weight * sisdr

    # For logging, also compute positive SDR metric
    with torch.no_grad():
        pos_sdr = new_sdr_metric(estimated, target).mean()

    metrics = {
        "loss/total": total.item(),
        "loss/sdr": sdr.item(),
        "loss/sisdr": sisdr.item(),
        "metrics/sdr": -sdr.item(),  # Positive SDR for logging
        "metrics/sisdr": -sisdr.item(),  # Positive SI-SDR for logging
        "metrics/new_sdr": pos_sdr.item(),  # MDX-style SDR
    }

    return total, metrics


def combined_L1_sdr_loss(
    estimated: torch.Tensor,
    target: torch.Tensor,
    sdr_weight: float = 1.0,
    l1_weight: float = 0.05
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Combined SDR and L1 loss.

    SI-SDR is also evaluated, but only reported in the metrics; it does not
    contribute to the returned loss.

    Args:
        estimated: Estimated audio (batch, channels, time)
        target: Target audio (batch, channels, time)
        sdr_weight: Weight for SDR loss (default 1.0)
        l1_weight: Weight for L1 loss (default 0.05)

    Returns:
        total_loss: Combined loss for backpropagation
        metrics: Dictionary of metrics for logging
    """
    sdr = sdr_loss(estimated, target)
    l1 = torch.nn.functional.l1_loss(estimated, target)

    total = sdr_weight * sdr + l1_weight * l1

    # SI-SDR is logged only, so skip building an autograd graph for it
    with torch.no_grad():
        sisdr = sisdr_loss(estimated, target)

    metrics = {
        "loss/total": total.item(),
        "loss/sdr": sdr.item(),
        "loss/sisdr": sisdr.item(),
        "loss/l1": l1.item(),
        "metrics/sdr": -sdr.item(),  # Positive SDR for logging
        "metrics/sisdr": -sisdr.item(),  # Positive SI-SDR for logging
    }

    return total, metrics