"""Loss functions for audio source separation: SDR, SI-SDR, and combined variants."""
from typing import Dict, Tuple
import torch
# ============================================================================
# Loss Functions
# ============================================================================
def sdr_loss(estimated, target):
    """
    Negative SDR loss (definition from Vincent et al. 2006).

    Each batch item is flattened, the signal-to-distortion ratio
    10*log10(||target||^2 / ||target - estimated||^2) is computed with a
    small epsilon for numerical stability, clamped to [-30, 30] dB to
    avoid extreme values, and the batch mean is negated so the result
    can be minimized.
    """
    eps = 1e-8  # numerical-stability constant
    batch = estimated.shape[0]
    est = estimated.reshape(batch, -1)
    ref = target.reshape(batch, -1)
    # Energy of the reference and of the residual error
    signal_energy = ref.pow(2).sum(dim=-1)
    error_energy = (ref - est).pow(2).sum(dim=-1)
    sdr_db = 10 * torch.log10((signal_energy + eps) / (error_energy + eps))
    # Clamp to a reasonable dB range, then negate for minimization
    return -sdr_db.clamp(min=-30, max=30).mean()
def sisdr_loss(estimated, target):
    """
    Negative SI-SDR (Scale-Invariant SDR) loss.

    Both signals are flattened per batch item and mean-centered (required
    for scale invariance). The estimate is projected onto the target; the
    orthogonal remainder is treated as noise. The resulting ratio in dB is
    clamped to [-30, 30] and the batch mean is negated for minimization.
    """
    eps = 1e-8
    batch = estimated.shape[0]
    est = estimated.reshape(batch, -1)
    ref = target.reshape(batch, -1)
    # Zero-mean normalization — critical for SI-SDR
    est = est - est.mean(dim=-1, keepdim=True)
    ref = ref - ref.mean(dim=-1, keepdim=True)
    # Projection of the estimate onto the target: <s', s> / ||s||^2 * s
    inner = (est * ref).sum(dim=-1, keepdim=True)
    ref_energy = ref.pow(2).sum(dim=-1, keepdim=True)
    projection = (inner / (ref_energy + eps)) * ref
    # Noise is whatever is left after removing the target component
    residual = est - projection
    ratio = (projection.pow(2).sum(dim=-1) + eps) / (residual.pow(2).sum(dim=-1) + eps)
    sisdr_db = 10 * torch.log10(ratio)
    return -sisdr_db.clamp(min=-30, max=30).mean()
def new_sdr_metric(estimated: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    SDR according to the MDX challenge definition (positive values).

    Intended for evaluation/logging, not as a training loss; note the
    result is NOT clamped, unlike sdr_loss.

    Args:
        estimated: (batch, channels, time)
        target: (batch, channels, time)

    Returns:
        SDR scores per batch item, shape (batch,)
    """
    eps = 1e-8
    # Sum energies over channel and time axes, keeping the batch axis
    ref_energy = target.pow(2).sum(dim=(1, 2))
    err_energy = (target - estimated).pow(2).sum(dim=(1, 2))
    return 10 * torch.log10((ref_energy + eps) / (err_energy + eps))
def combined_loss(
    estimated: torch.Tensor,
    target: torch.Tensor,
    sdr_weight: float = 0.9,
    sisdr_weight: float = 0.1
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Weighted sum of SDR and SI-SDR losses.

    Args:
        estimated: Estimated audio (batch, channels, time)
        target: Target audio (batch, channels, time)
        sdr_weight: Weight for SDR loss (default 0.9)
        sisdr_weight: Weight for SI-SDR loss (default 0.1)

    Returns:
        total_loss: Combined loss for backpropagation
        metrics: Dictionary of scalar metrics for logging
    """
    sdr_term = sdr_loss(estimated, target)
    sisdr_term = sisdr_loss(estimated, target)
    total = sdr_weight * sdr_term + sisdr_weight * sisdr_term

    # MDX-style positive SDR is logging-only; keep it out of the graph.
    with torch.no_grad():
        mdx_sdr = new_sdr_metric(estimated, target).mean()

    metrics = {}
    metrics["loss/total"] = total.item()
    metrics["loss/sdr"] = sdr_term.item()
    metrics["loss/sisdr"] = sisdr_term.item()
    metrics["metrics/sdr"] = -sdr_term.item()      # Positive SDR for logging
    metrics["metrics/sisdr"] = -sisdr_term.item()  # Positive SI-SDR for logging
    metrics["metrics/new_sdr"] = mdx_sdr.item()    # MDX-style SDR
    return total, metrics
def combined_L1_sdr_loss(
    estimated: torch.Tensor,
    target: torch.Tensor,
    sdr_weight: float = 1.0,
    l1_weight: float = 0.05
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Combined SDR and L1 loss.

    Args:
        estimated: Estimated audio (batch, channels, time)
        target: Target audio (batch, channels, time)
        sdr_weight: Weight for SDR loss (default 1.0)
        l1_weight: Weight for L1 loss (default 0.05)

    Returns:
        total_loss: Combined loss for backpropagation
        metrics: Dictionary of scalar metrics for logging
    """
    sdr = sdr_loss(estimated, target)
    l1 = torch.nn.functional.l1_loss(estimated, target)
    total = sdr_weight * sdr + l1_weight * l1
    # SI-SDR is logged only (not part of `total`), so compute it without
    # building a graph — mirrors the no_grad convention in combined_loss.
    with torch.no_grad():
        sisdr = sisdr_loss(estimated, target)
    metrics = {
        "loss/total": total.item(),
        "loss/sdr": sdr.item(),
        "loss/l1": l1.item(),
        "loss/sisdr": sisdr.item(),
        "metrics/sdr": -sdr.item(),      # Positive SDR for logging
        "metrics/sisdr": -sisdr.item(),  # Positive SI-SDR for logging
    }
    return total, metrics