# Copyright (c) 2024 Amphion.
#
# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/loss.py
# Licensed under Apache License 2.0

from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa

from models.vocoders.gan.discriminator.mpd import MultiScaleMultiPeriodDiscriminator
from models.tts.jets.alignments import make_non_pad_mask, make_pad_mask


class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module (least-squares GAN objective)."""

    def __init__(self):
        super().__init__()

    def forward(self, outputs) -> torch.Tensor:
        """Calculate generator adversarial loss.

        Args:
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs calculated from generator.

        Returns:
            Tensor: Generator adversarial loss value.
        """
        if isinstance(outputs, (tuple, list)):
            adv_loss = 0.0
            for i, outputs_ in enumerate(outputs):
                if isinstance(outputs_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_ = outputs_[-1]
                adv_loss += F.mse_loss(outputs_, outputs_.new_ones(outputs_.size()))
        else:
            adv_loss = F.mse_loss(outputs, outputs.new_ones(outputs.size()))
        return adv_loss
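

# Illustrative usage sketch (not part of the original file; tensors are
# hypothetical): the LSGAN generator loss drives discriminator scores toward 1.
# With per-discriminator outputs whose last entry is the score map:
#
#     gen_adv = GeneratorAdversarialLoss()
#     p_hat = [[torch.randn(2, 16, 10), torch.randn(2, 1, 10)] for _ in range(3)]
#     loss = gen_adv(p_hat)  # sum of MSE(score, 1) over the 3 discriminators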


class FeatureMatchLoss(torch.nn.Module):
    """Feature matching loss module."""

    def __init__(
        self,
        average_by_layers: bool = False,
        average_by_discriminators: bool = False,
        include_final_outputs: bool = True,
    ):
        """Initialize FeatureMatchLoss module.

        Args:
            average_by_layers (bool): Whether to average the loss by the number
                of layers.
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            include_final_outputs (bool): Whether to include the final output of
                each discriminator for loss calculation.
        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators
        self.include_final_outputs = include_final_outputs

    def forward(
        self,
        feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
        feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
    ) -> torch.Tensor:
        """Calculate feature matching loss.

        Args:
            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from generator's outputs.
            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from ground truth.

        Returns:
            Tensor: Feature matching loss value.
        """
        feat_match_loss = 0.0
        for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
            feat_match_loss_ = 0.0
            if not self.include_final_outputs:
                feats_hat_ = feats_hat_[:-1]
                feats_ = feats_[:-1]
            for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
                feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
            if self.average_by_layers:
                feat_match_loss_ /= j + 1
            feat_match_loss += feat_match_loss_
        if self.average_by_discriminators:
            feat_match_loss /= i + 1
        return feat_match_loss
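

# Illustrative sketch (hypothetical tensors, not from the original file):
# ground-truth feature maps are detached in forward(), so gradients only flow
# through the generator branch:
#
#     fm = FeatureMatchLoss(average_by_layers=True)
#     feats_hat = [[torch.randn(2, 16, 10) for _ in range(3)]]  # generator side
#     feats = [[torch.randn(2, 16, 10) for _ in range(3)]]      # ground-truth side
#     loss = fm(feats_hat, feats)  # mean L1 across the 3 layers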


class DurationPredictorLoss(torch.nn.Module):
    """Loss function module for duration predictor.

    The loss value is calculated in log domain to make it Gaussian.
    """

    def __init__(self, offset=1.0, reduction="mean"):
        """Initialize duration predictor loss module.

        Args:
            offset (float, optional): Offset value to avoid nan in log domain.
            reduction (str): Reduction type in loss calculation.
        """
        super().__init__()
        self.criterion = torch.nn.MSELoss(reduction=reduction)
        self.offset = offset

    def forward(self, outputs, targets):
        """Calculate duration predictor loss.

        Args:
            outputs (Tensor): Batch of predicted durations in log domain (B, T).
            targets (LongTensor): Batch of ground-truth durations in linear domain (B, T).

        Returns:
            Tensor: Mean squared error loss value.
        """
        # NOTE: targets are converted to log domain to match outputs
        targets = torch.log(targets.float() + self.offset)
        loss = self.criterion(outputs, targets)
        return loss
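

# A minimal worked example (hypothetical values): for a target duration of 3
# frames and offset 1.0, the regression target is log(3 + 1) ≈ 1.386, so a
# predictor output of exactly 1.386 (log domain) yields zero loss:
#
#     crit = DurationPredictorLoss()
#     targets = torch.tensor([[3.0]])
#     outputs = torch.log(targets + 1.0)
#     crit(outputs, targets)  # -> tensor(0.)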


class VarianceLoss(torch.nn.Module):
    def __init__(self):
        """Initialize JETS variance loss module."""
        super().__init__()

        # define criterions
        reduction = "mean"
        self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)

    def forward(
        self,
        d_outs: torch.Tensor,
        ds: torch.Tensor,
        p_outs: torch.Tensor,
        ps: torch.Tensor,
        e_outs: torch.Tensor,
        es: torch.Tensor,
        ilens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            d_outs (Tensor): Batch of outputs of duration predictor (B, T_text).
            ds (LongTensor): Batch of durations (B, T_text).
            p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text).
            ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1).
            e_outs (Tensor): Batch of outputs of energy predictor (B, T_text).
            es (Tensor): Batch of target token-averaged energy (B, T_text, 1).
            ilens (LongTensor): Batch of the lengths of each input (B,).

        Returns:
            Tensor: Duration predictor loss value.
            Tensor: Pitch predictor loss value.
            Tensor: Energy predictor loss value.
        """
        # apply mask to remove padded part
        duration_masks = make_non_pad_mask(ilens).to(ds.device)
        d_outs = d_outs.masked_select(duration_masks)
        ds = ds.masked_select(duration_masks)
        # the predictor outputs are (B, T_text) while the targets keep a
        # trailing channel dim (B, T_text, 1), hence the two mask variants
        pitch_masks = make_non_pad_mask(ilens).to(ds.device)
        pitch_masks_ = make_non_pad_mask(ilens).unsqueeze(-1).to(ds.device)
        p_outs = p_outs.masked_select(pitch_masks)
        e_outs = e_outs.masked_select(pitch_masks)
        ps = ps.masked_select(pitch_masks_)
        es = es.masked_select(pitch_masks_)

        # calculate loss
        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.mse_criterion(p_outs, ps)
        energy_loss = self.mse_criterion(e_outs, es)

        return duration_loss, pitch_loss, energy_loss
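

# Illustrative shapes (assumed from the masking above, not stated in the
# original file): predictor outputs are (B, T_text), pitch/energy targets are
# (B, T_text, 1):
#
#     var = VarianceLoss()
#     ilens = torch.tensor([5, 3])
#     d_l, p_l, e_l = var(
#         torch.randn(2, 5), torch.randint(1, 4, (2, 5)),   # durations
#         torch.randn(2, 5), torch.randn(2, 5, 1),          # pitch
#         torch.randn(2, 5), torch.randn(2, 5, 1),          # energy
#         ilens,
#     )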


class ForwardSumLoss(torch.nn.Module):
    """Forward-sum loss described at https://openreview.net/forum?id=0NQwnnwAORi"""

    def __init__(self):
        """Initialize forward-sum loss module."""
        super().__init__()

    def forward(
        self,
        log_p_attn: torch.Tensor,
        ilens: torch.Tensor,
        olens: torch.Tensor,
        blank_prob: float = np.e**-1,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            log_p_attn (Tensor): Batch of log probability of attention matrix
                (B, T_feats, T_text).
            ilens (Tensor): Batch of the lengths of each input (B,).
            olens (Tensor): Batch of the lengths of each target (B,).
            blank_prob (float): Blank symbol probability.

        Returns:
            Tensor: Forward-sum loss value.
        """
        B = log_p_attn.size(0)

        # a blank-token column must be prepended to the attention matrix
        # along the text axis to account for the blank token of CTC loss
        # (B, T_feats, T_text + 1)
        log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob))

        loss = 0
        for bidx in range(B):
            # construct target sequence.
            # Every text token is mapped to a unique sequence number.
            target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0)
            cur_log_p_attn_pd = log_p_attn_pd[
                bidx, : olens[bidx], : ilens[bidx] + 1
            ].unsqueeze(1)  # (T_feats, 1, T_text + 1)
            cur_log_p_attn_pd = F.log_softmax(cur_log_p_attn_pd, dim=-1)
            loss += F.ctc_loss(
                log_probs=cur_log_p_attn_pd,
                targets=target_seq,
                input_lengths=olens[bidx : bidx + 1],
                target_lengths=ilens[bidx : bidx + 1],
                zero_infinity=True,
            )
        loss = loss / B
        return loss
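

# Intuition (illustrative sketch, hypothetical tensors): each text token is
# treated as a distinct CTC label (1..T_text), with label 0 reserved for the
# blank column prepended above, so F.ctc_loss marginalizes over all monotonic
# alignments of the T_feats frames to the token sequence:
#
#     fsl = ForwardSumLoss()
#     log_p_attn = torch.randn(2, 40, 5).log_softmax(dim=-1)
#     loss = fsl(log_p_attn, ilens=torch.tensor([5, 4]), olens=torch.tensor([40, 30]))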


class MelSpectrogramLoss(torch.nn.Module):
    """Mel-spectrogram loss."""

    def __init__(
        self,
        fs: int = 22050,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: Optional[int] = None,
        window: str = "hann",
        n_mels: int = 80,
        fmin: Optional[int] = 0,
        fmax: Optional[int] = None,
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        htk: bool = False,
    ):
        """Initialize Mel-spectrogram loss.

        Args:
            fs (int): Sampling rate.
            n_fft (int): FFT points.
            hop_length (int): Hop length.
            win_length (Optional[int]): Window length. Defaults to n_fft.
            window (str): Window type.
            n_mels (int): Number of Mel basis.
            fmin (Optional[int]): Minimum frequency for Mel.
            fmax (Optional[int]): Maximum frequency for Mel.
            center (bool): Whether to use center window.
            normalized (bool): Whether to use normalized one.
            onesided (bool): Whether to use onesided one.
            htk (bool): Whether to use the HTK formula for the Mel filterbank.
        """
        super().__init__()
        self.fs = fs
        self.n_fft = n_fft
        self.hop_length = hop_length
        # fall back to n_fft when win_length is not given
        self.win_length = win_length if win_length is not None else n_fft
        self.window = window
        self.n_mels = n_mels
        self.fmin = 0 if fmin is None else fmin
        self.fmax = fs / 2 if fmax is None else fmax
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        self.htk = htk

    def logmel(self, feat, ilens):
        # build the Mel filterbank and apply it to the amplitude spectrogram
        mel_options = dict(
            sr=self.fs,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.fmin,
            fmax=self.fmax,
            htk=self.htk,
        )
        melmat = librosa.filters.mel(**mel_options)
        melmat = torch.from_numpy(melmat.T).float().to(feat.device)
        mel_feat = torch.matmul(feat, melmat)
        mel_feat = torch.clamp(mel_feat, min=1e-10)
        logmel_feat = mel_feat.log10()

        # zero out the padded part
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(
                make_pad_mask(ilens, logmel_feat, 1), 0.0
            )
        return logmel_feat

    def wav_to_mel(self, input, input_lengths=None):
        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(
                self.win_length, dtype=input.dtype, device=input.device
            )
        else:
            window = None
        stft_kwargs = dict(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            center=self.center,
            window=window,
            normalized=self.normalized,
            onesided=self.onesided,
            return_complex=True,
        )

        bs = input.size(0)
        if input.dim() == 3:
            multi_channel = True
            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
            input = input.transpose(1, 2).reshape(-1, input.size(1))
        else:
            multi_channel = False

        input_stft = torch.stft(input, **stft_kwargs)
        input_stft = torch.view_as_real(input_stft)
        # (Batch, Freq, Frames, 2) -> (Batch, Frames, Freq, 2)
        input_stft = input_stft.transpose(1, 2)
        if multi_channel:
            input_stft = input_stft.view(
                bs, -1, input_stft.size(1), input_stft.size(2), 2
            ).transpose(1, 2)

        if input_lengths is not None:
            if self.center:
                pad = self.n_fft // 2
                input_lengths = input_lengths + 2 * pad
            feats_lens = (input_lengths - self.n_fft) // self.hop_length + 1
            input_stft.masked_fill_(make_pad_mask(feats_lens, input_stft, 1), 0.0)
        else:
            feats_lens = None

        input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
        input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
        input_feats = self.logmel(input_amp, feats_lens)
        return input_feats, feats_lens

    def forward(
        self,
        y_hat: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate L1 Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated waveform (B, 1, T).
            y (Tensor): Ground-truth waveform (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram loss value.
        """
        mel_hat, _ = self.wav_to_mel(y_hat.squeeze(1))
        mel, _ = self.wav_to_mel(y.squeeze(1))
        mel_loss = F.l1_loss(mel_hat, mel)
        return mel_loss
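

# Illustrative usage sketch (hypothetical shapes): waveforms come in as
# (B, 1, T) and are squeezed to (B, T) before the STFT:
#
#     mel_crit = MelSpectrogramLoss(fs=22050, n_fft=1024, hop_length=256)
#     y_hat, y = torch.randn(2, 1, 8192), torch.randn(2, 1, 8192)
#     loss = mel_crit(y_hat, y)  # L1 on log10-Mel features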


class GeneratorLoss(nn.Module):
    """The total loss of the generator."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.mel_loss = MelSpectrogramLoss()
        self.generator_adv_loss = GeneratorAdversarialLoss()
        self.feat_match_loss = FeatureMatchLoss()
        self.var_loss = VarianceLoss()
        self.forwardsum_loss = ForwardSumLoss()

        # loss weights
        self.lambda_adv = 1.0
        self.lambda_mel = 45.0
        self.lambda_feat_match = 2.0
        self.lambda_var = 1.0
        self.lambda_align = 2.0

    def forward(self, outputs_g, outputs_d, speech_):
        loss_g = {}

        # parse generator output
        (
            speech_hat_,
            bin_loss,
            log_p_attn,
            start_idxs,
            d_outs,
            ds,
            p_outs,
            ps,
            e_outs,
            es,
            text_lengths,
            feats_lengths,
        ) = outputs_g

        # parse discriminator output
        (p_hat, p) = outputs_d

        # calculate losses
        mel_loss = self.mel_loss(speech_hat_, speech_)
        adv_loss = self.generator_adv_loss(p_hat)
        feat_match_loss = self.feat_match_loss(p_hat, p)
        dur_loss, pitch_loss, energy_loss = self.var_loss(
            d_outs, ds, p_outs, ps, e_outs, es, text_lengths
        )
        forwardsum_loss = self.forwardsum_loss(log_p_attn, text_lengths, feats_lengths)

        # calculate total loss
        mel_loss = mel_loss * self.lambda_mel
        loss_g["mel_loss"] = mel_loss
        adv_loss = adv_loss * self.lambda_adv
        loss_g["adv_loss"] = adv_loss
        feat_match_loss = feat_match_loss * self.lambda_feat_match
        loss_g["feat_match_loss"] = feat_match_loss
        g_loss = mel_loss + adv_loss + feat_match_loss
        loss_g["g_loss"] = g_loss
        var_loss = (dur_loss + pitch_loss + energy_loss) * self.lambda_var
        loss_g["var_loss"] = var_loss
        align_loss = (forwardsum_loss + bin_loss) * self.lambda_align
        loss_g["align_loss"] = align_loss

        g_total_loss = g_loss + var_loss + align_loss
        loss_g["g_total_loss"] = g_total_loss

        return loss_g
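

# With the weights set in __init__, the total objective computed in forward()
# above works out to
#
#     g_total = 45 * L_mel + 1 * L_adv + 2 * L_feat
#               + 1 * (L_dur + L_pitch + L_energy)
#               + 2 * (L_forwardsum + L_bin)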


class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Discriminator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "mse",
    ):
        """Initialize DiscriminatorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".
        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(
        self,
        outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of discriminator
                outputs calculated from generator.
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of discriminator
                outputs calculated from ground truth.

        Returns:
            Tensor: Discriminator real loss value.
            Tensor: Discriminator fake loss value.
        """
        if isinstance(outputs, (tuple, list)):
            real_loss = 0.0
            fake_loss = 0.0
            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
                if isinstance(outputs_hat_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_hat_ = outputs_hat_[-1]
                    outputs_ = outputs_[-1]
                real_loss += self.real_criterion(outputs_)
                fake_loss += self.fake_criterion(outputs_hat_)
            if self.average_by_discriminators:
                fake_loss /= i + 1
                real_loss /= i + 1
        else:
            real_loss = self.real_criterion(outputs)
            fake_loss = self.fake_criterion(outputs_hat)
        return real_loss, fake_loss

    def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))

    def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
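

# Illustrative sketch (hypothetical tensors): with the default "mse" loss the
# real branch is pulled toward 1 and the fake branch toward 0, so perfectly
# separated scores give zero loss on both sides:
#
#     disc_adv = DiscriminatorAdversarialLoss(loss_type="mse")
#     real_loss, fake_loss = disc_adv(
#         [torch.zeros(2, 1)],  # scores on generated audio -> fake_loss = 0
#         [torch.ones(2, 1)],   # scores on real audio      -> real_loss = 0
#     )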


class DiscriminatorLoss(torch.nn.Module):
    """The total loss of the discriminator."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.discriminator = MultiScaleMultiPeriodDiscriminator()
        self.discriminator_adv_loss = DiscriminatorAdversarialLoss()

    def forward(self, speech_real, speech_generated):
        # `speech_real` / `speech_generated` carry the discriminator output
        # structures for real and generated speech, respectively
        # (see DiscriminatorAdversarialLoss.forward)
        loss_d = {}

        real_loss, fake_loss = self.discriminator_adv_loss(
            speech_generated, speech_real
        )
        loss_d["loss_disc_all"] = real_loss + fake_loss

        return loss_d
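

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; assumes the repository imports at
    # the top of this file resolve). Shapes follow the docstrings above.
    torch.manual_seed(0)

    # variance losses on random predictor outputs / targets
    ilens = torch.tensor([5, 3])
    dur_l, pitch_l, energy_l = VarianceLoss()(
        torch.randn(2, 5),
        torch.randint(1, 4, (2, 5)),
        torch.randn(2, 5),
        torch.randn(2, 5, 1),
        torch.randn(2, 5),
        torch.randn(2, 5, 1),
        ilens,
    )

    # forward-sum alignment loss on a random log-attention matrix
    align_l = ForwardSumLoss()(
        torch.randn(2, 40, 5).log_softmax(dim=-1),
        ilens,
        torch.tensor([40, 30]),
    )

    # Mel loss between two random waveforms
    mel_l = MelSpectrogramLoss()(torch.randn(2, 1, 8192), torch.randn(2, 1, 8192))

    print(
        f"dur={dur_l.item():.4f} pitch={pitch_l.item():.4f} "
        f"energy={energy_l.item():.4f} align={align_l.item():.4f} "
        f"mel={mel_l.item():.4f}"
    )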