File size: 5,173 Bytes
7417a6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from typing import Dict, Tuple
import torch


# ============================================================================
# Loss Functions
# ============================================================================

def sdr_loss(estimated, target):
    """
    Negative signal-to-distortion ratio, averaged over the batch.

    Per batch item the ratio is
    ``10 * log10(||target||^2 / ||target - estimated||^2)``
    (the docstring this replaces cites Vincent et al. 2006 for the
    definition). Both inputs are collapsed to ``[batch, -1]`` first, so any
    trailing dimensions are treated as a single signal.

    Args:
        estimated: Estimated signal; first dimension is the batch.
        target: Reference signal with the same number of elements per item.

    Returns:
        Scalar tensor: minus the mean of the per-item SDR in dB, where each
        per-item value is clamped to [-30, 30] before averaging.
    """
    eps = 1e-8  # keeps the ratio finite when either energy is zero

    # Collapse everything after the batch axis into one signal per item.
    reference = target.reshape(target.shape[0], -1)
    residual = reference - estimated.reshape(estimated.shape[0], -1)

    signal_energy = reference.pow(2).sum(dim=-1)
    error_energy = residual.pow(2).sum(dim=-1)

    ratio_db = 10 * torch.log10((signal_energy + eps) / (error_energy + eps))

    # Bound the per-item value so a single extreme item cannot dominate.
    ratio_db = ratio_db.clamp(min=-30, max=30)

    # Negated so that minimizing the loss maximizes SDR.
    return -ratio_db.mean()


def sisdr_loss(estimated, target):
    """
    Negative scale-invariant SDR (SI-SDR), averaged over the batch.

    Each item is flattened to one signal and mean-centered. The estimate is
    then split into its projection onto the target and the remainder; the
    ratio of those two energies, in dB, is the SI-SDR. Because of the
    projection, rescaling the estimate does not change the result.

    Args:
        estimated: Estimated signal; first dimension is the batch.
        target: Reference signal with the same number of elements per item.

    Returns:
        Scalar tensor: minus the mean of the per-item SI-SDR in dB, where
        each per-item value is clamped to [-30, 30] before averaging.
    """
    eps = 1e-8  # guards both divisions against zero energy

    # One flattened signal per batch item.
    estimate = estimated.reshape(estimated.shape[0], -1)
    reference = target.reshape(target.shape[0], -1)

    # Remove the DC component from both signals.
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    reference = reference - reference.mean(dim=-1, keepdim=True)

    # Projection coefficient <estimate, reference> / ||reference||^2.
    inner = (estimate * reference).sum(dim=-1, keepdim=True)
    reference_energy = reference.pow(2).sum(dim=-1, keepdim=True)
    projection = (inner / (reference_energy + eps)) * reference

    # Whatever is left of the estimate after removing the projection.
    remainder = estimate - projection

    projection_energy = projection.pow(2).sum(dim=-1)
    remainder_energy = remainder.pow(2).sum(dim=-1)

    ratio_db = 10 * torch.log10(
        (projection_energy + eps) / (remainder_energy + eps)
    )
    ratio_db = ratio_db.clamp(min=-30, max=30)

    # Negated so that minimizing the loss maximizes SI-SDR.
    return -ratio_db.mean()


def new_sdr_metric(estimated: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Per-item SDR in dB following the MDX challenge definition.

    Higher is better (values are not negated), and no clamping is applied.
    Intended for evaluation/logging rather than as a training loss.

    Args:
        estimated: (batch, channels, time)
        target: (batch, channels, time)

    Returns:
        SDR scores per batch item (batch,)
    """
    eps = 1e-8  # keeps the ratio finite for silent targets or perfect estimates
    reduce_dims = (1, 2)  # sum over channels and time, keep the batch axis

    error = target - estimated
    signal_energy = target.pow(2).sum(dim=reduce_dims)
    error_energy = error.pow(2).sum(dim=reduce_dims)

    return 10 * torch.log10((signal_energy + eps) / (error_energy + eps))


def combined_loss(
        estimated: torch.Tensor,
        target: torch.Tensor,
        sdr_weight: float = 0.9,
        sisdr_weight: float = 0.1
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Weighted sum of the negative SDR and negative SI-SDR losses.

    Args:
        estimated: Estimated audio (batch, channels, time)
        target: Target audio (batch, channels, time)
        sdr_weight: Multiplier applied to the SDR loss (default 0.9)
        sisdr_weight: Multiplier applied to the SI-SDR loss (default 0.1)

    Returns:
        total_loss: Weighted loss tensor to backpropagate through
        metrics: Plain-float values for logging: the total and individual
            losses, their sign-flipped (positive) counterparts, and the
            batch-mean MDX-style SDR
    """
    neg_sdr = sdr_loss(estimated, target)
    neg_sisdr = sisdr_loss(estimated, target)
    weighted_total = sdr_weight * neg_sdr + sisdr_weight * neg_sisdr

    # The MDX-style SDR is only logged, so keep it out of the autograd graph.
    with torch.no_grad():
        mdx_sdr = new_sdr_metric(estimated, target).mean()

    # Pull each scalar off the tensor once and reuse it below.
    sdr_value = neg_sdr.item()
    sisdr_value = neg_sisdr.item()

    metrics = {
        "loss/total": weighted_total.item(),
        "loss/sdr": sdr_value,
        "loss/sisdr": sisdr_value,
        "metrics/sdr": -sdr_value,  # sign flipped back to positive SDR
        "metrics/sisdr": -sisdr_value,  # sign flipped back to positive SI-SDR
        "metrics/new_sdr": mdx_sdr.item(),  # MDX-style SDR
    }

    return weighted_total, metrics


def combined_L1_sdr_loss(
        estimated: torch.Tensor,
        target: torch.Tensor,
        sdr_weight: float = 1.0,
        l1_weight: float = 0.05
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Combined SDR and L1 loss.

    The total is ``sdr_weight * sdr_loss + l1_weight * l1_loss``. SI-SDR is
    reported in the metrics but does not contribute to the total.

    Args:
        estimated: Estimated audio (batch, channels, time)
        target: Target audio (batch, channels, time)
        sdr_weight: Weight for SDR loss (default 1.0)
        l1_weight: Weight for L1 loss (default 0.05)

    Returns:
        total_loss: Combined loss for backpropagation
        metrics: Dictionary of metrics for logging
    """
    sdr = sdr_loss(estimated, target)
    l1 = torch.nn.functional.l1_loss(estimated, target)

    total = sdr_weight * sdr + l1_weight * l1

    # SI-SDR is only logged here, never optimized, so compute it without
    # building an autograd graph.
    with torch.no_grad():
        sisdr = sisdr_loss(estimated, target)

    metrics = {
        "loss/total": total.item(),
        "loss/sdr": sdr.item(),
        "loss/sisdr": sisdr.item(),
        "loss/l1": l1.item(),  # the L1 term is part of the total, so log it too
        "metrics/sdr": -sdr.item(),  # Positive SDR for logging
        "metrics/sisdr": -sisdr.item(),  # Positive SI-SDR for logging
    }

    return total, metrics