""" AudioTextHTDemucs v2 - Text-conditioned source separation. Changes from v1: - Custom trainable decoder that outputs 1 source (not 4) - HTDemucs encoder kept (frozen) - CLAP text encoder (frozen) - Cross-attention conditioning at bottleneck """ import torch import torch.nn as nn import torch.nn.functional as F from typing import List, Any from fractions import Fraction from einops import rearrange from demucs.htdemucs import HTDemucs from transformers import ClapModel, ClapTextModelWithProjection, RobertaTokenizerFast class TextCrossAttention(nn.Module): """Cross-attention: audio features attend to text embeddings.""" def __init__(self, feat_dim, text_dim, n_heads=8, dropout=0.0): super().__init__() self.q_proj = nn.Linear(feat_dim, feat_dim) self.k_proj = nn.Linear(text_dim, feat_dim) self.v_proj = nn.Linear(text_dim, feat_dim) self.attn = nn.MultiheadAttention(feat_dim, n_heads, batch_first=True, dropout=dropout) self.out_mlp = nn.Sequential( nn.Linear(feat_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, feat_dim), ) self.norm_q = nn.LayerNorm(feat_dim) self.norm_out = nn.LayerNorm(feat_dim) def forward_attend(self, queries, text_emb): q = self.norm_q(queries) if text_emb.dim() == 2: text_emb = text_emb.unsqueeze(1) k = self.k_proj(text_emb) v = self.v_proj(text_emb) q_proj = self.q_proj(q) attn_out, _ = self.attn(query=q_proj, key=k, value=v) out = queries + attn_out out = out + self.out_mlp(out) return self.norm_out(out) def forward(self, x, xt, text_emb): B, C, F, T = x.shape x_seq = rearrange(x, "b c f t -> b (f t) c") xt_seq = rearrange(xt, "b c t -> b t c") x_seq = self.forward_attend(x_seq, text_emb) xt_seq = self.forward_attend(xt_seq, text_emb) x = rearrange(x_seq, "b (f t) c -> b c f t", f=F, t=T) xt = rearrange(xt_seq, "b t c -> b c t") return x, xt class FreqDecoder(nn.Module): """Frequency-domain decoder: mirrors HTDemucs encoder structure but outputs 1 source.""" def __init__(self, channels: List[int], kernel_size: int = 8, stride: int = 4): """ channels: List of channel dims from bottleneck to output, e.g. 
class FreqDecoder(nn.Module):
    """Frequency-domain decoder: mirrors HTDemucs encoder structure but outputs 1 source."""

    def __init__(self, channels: List[int], kernel_size: int = 8, stride: int = 4):
        """
        channels: List of channel dims from bottleneck to output,
            e.g. [384, 192, 96, 48, 2]
        """
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(len(channels) - 1):
            in_ch = channels[i]
            out_ch = channels[i + 1]
            is_last = (i == len(channels) - 2)
            self.layers.append(nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch,
                                   kernel_size=(kernel_size, 1),
                                   stride=(stride, 1),
                                   padding=(kernel_size // 4, 0)),
                nn.GroupNorm(1, out_ch) if not is_last else nn.Identity(),
                nn.GELU() if not is_last else nn.Identity(),
            ))

    def forward(self, x, skips: List[torch.Tensor], target_lengths: List[int]):
        """
        x: (B, C, F, T) bottleneck features
        skips: encoder skip connections (reversed order)
        target_lengths: target frequency dimensions for each layer
        """
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Match target size
            if i < len(target_lengths):
                target_f = target_lengths[i]
                if x.shape[2] != target_f:
                    x = F.interpolate(x, size=(target_f, x.shape[3]),
                                      mode='bilinear', align_corners=False)
            # Add skip connection if available
            if i < len(skips):
                skip = skips[i]
                # Project skip to match channels if needed
                if skip.shape[1] != x.shape[1]:
                    skip = skip[:, :x.shape[1]]  # Simple channel truncation
                if skip.shape[2:] != x.shape[2:]:
                    skip = F.interpolate(skip, size=x.shape[2:],
                                         mode='bilinear', align_corners=False)
                x = x + skip * 0.1  # Scaled residual
        return x


class TimeDecoder(nn.Module):
    """Time-domain decoder: outputs 1 source waveform."""

    def __init__(self, channels: List[int], kernel_size: int = 8, stride: int = 4):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(len(channels) - 1):
            in_ch = channels[i]
            out_ch = channels[i + 1]
            is_last = (i == len(channels) - 2)
            self.layers.append(nn.Sequential(
                nn.ConvTranspose1d(in_ch, out_ch, kernel_size, stride,
                                   padding=kernel_size // 4),
                nn.GroupNorm(1, out_ch) if not is_last else nn.Identity(),
                nn.GELU() if not is_last else nn.Identity(),
            ))

    def forward(self, x, skips: List[torch.Tensor], target_lengths: List[int]):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(target_lengths):
                target_t = target_lengths[i]
                if x.shape[2] != target_t:
                    x = F.interpolate(x, size=target_t, mode='linear', align_corners=False)
            if i < len(skips):
                skip = skips[i]
                if skip.shape[1] != x.shape[1]:
                    skip = skip[:, :x.shape[1]]
                if skip.shape[2] != x.shape[2]:
                    skip = F.interpolate(skip, size=x.shape[2], mode='linear', align_corners=False)
                x = x + skip * 0.1
        return x
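
# Illustrative only: shape behaviour of the decoders under the defaults
# kernel_size=8, stride=4, padding=kernel_size // 4 == 2, where each
# ConvTranspose layer upsamples its spatial dimension 4x:
#   L_out = (L_in - 1) * 4 - 2 * 2 + 8 = 4 * L_in
# Assuming an HTDemucs-style bottleneck (384 channels, 8 frequency bins) and
# no skips / target lengths:
#
#   FreqDecoder([384, 192, 96, 48, 4])(torch.randn(1, 384, 8, 32), [], [])
#       -> (1, 4, 2048, 32)
#   TimeDecoder([384, 192, 96, 48, 4])(torch.randn(1, 384, 32), [], [])
#       -> (1, 4, 8192)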
class AudioTextHTDemucs(nn.Module):
    """
    Text-conditioned source separation.

    - HTDemucs encoder (frozen): extracts multi-scale audio features
    - CLAP (frozen): text embeddings
    - Cross-attention: conditions audio on text at bottleneck
    - Custom decoder (trainable): outputs single source
    """

    def __init__(
        self,
        htdemucs_model: HTDemucs,
        clap_encoder: ClapModel | ClapTextModelWithProjection,
        clap_tokenizer: RobertaTokenizerFast,
        model_dim: int = 384,
        text_dim: int = 512,
        num_heads: int = 8,
        sample_rate: int = 44100,
        segment: float = 7.8,
    ):
        super().__init__()
        self.htdemucs = htdemucs_model
        self.clap = clap_encoder
        self.tokenizer = clap_tokenizer
        self.sample_rate = sample_rate
        self.segment = segment

        # Freeze HTDemucs encoder
        for param in self.htdemucs.parameters():
            param.requires_grad = False
        # Freeze CLAP
        for param in self.clap.parameters():
            param.requires_grad = False

        # Text cross-attention at bottleneck
        self.text_attn = TextCrossAttention(model_dim, text_dim, num_heads)

        # Custom decoders (trainable) - output 1 source with 2 channels (stereo)
        # Channel progression: 384 -> 192 -> 96 -> 48 -> 4 (projected to 2 channels below)
        self.freq_decoder = FreqDecoder([384, 192, 96, 48, 4])
        self.time_decoder = TimeDecoder([384, 192, 96, 48, 4])

        # Final projection to stereo
        self.freq_out = nn.Conv2d(4, 2, 1)
        self.time_out = nn.Conv1d(4, 2, 1)

    def _encode(self, x, xt):
        """Run HTDemucs encoder, save skip connections."""
        saved = []
        saved_t = []
        lengths = []
        lengths_t = []
        for idx, encode in enumerate(self.htdemucs.encoder):
            # Record the frequency size before each layer; the custom FreqDecoder
            # uses these (reversed) as per-layer upsampling targets.
            lengths.append(x.shape[-2])
            inject = None
            if idx < len(self.htdemucs.tencoder):
                lengths_t.append(xt.shape[-1])
                tenc = self.htdemucs.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    saved_t.append(xt)
                else:
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.htdemucs.freq_emb is not None:
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.htdemucs.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.htdemucs.freq_emb_scale * emb
            saved.append(x)

        # Cross-transformer at bottleneck
        if self.htdemucs.crosstransformer:
            if self.htdemucs.bottom_channels:
                b, c, f, t = x.shape
                x = rearrange(x, "b c f t -> b c (f t)")
                x = self.htdemucs.channel_upsampler(x)
                x = rearrange(x, "b c (f t) -> b c f t", f=f)
                xt = self.htdemucs.channel_upsampler_t(xt)
            x, xt = self.htdemucs.crosstransformer(x, xt)
            if self.htdemucs.bottom_channels:
                x = rearrange(x, "b c f t -> b c (f t)")
                x = self.htdemucs.channel_downsampler(x)
                x = rearrange(x, "b c (f t) -> b c f t", f=f)
                xt = self.htdemucs.channel_downsampler_t(xt)

        return x, xt, saved, saved_t, lengths, lengths_t

    def _get_clap_embeddings(self, text: List[str], device):
        inputs = self.tokenizer(text, padding=True, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            if isinstance(self.clap, ClapModel):
                # Full CLAP model: use the text tower via get_text_features
                return self.clap.get_text_features(**inputs)
            else:
                # ClapTextModelWithProjection: forward returns projected embeddings
                return self.clap(**inputs).text_embeds

    def forward(self, wav, text):
        """
        wav: (B, 2, T) stereo mixture
        text: List[str] prompts
        Returns: (B, 2, T) separated stereo source
        """
        device = wav.device
        original_length = wav.shape[-1]

        # Compute spectrogram (ensure all on same device)
        z = self.htdemucs._spec(wav).to(device)
        mag = self.htdemucs._magnitude(z).to(device)
        x = mag
        B, C, Fq, T_spec = x.shape

        # Normalize
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)

        xt = wav
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        # Encode (frozen)
        with torch.no_grad():
            x_enc, xt_enc, saved, saved_t, lengths, lengths_t = self._encode(x, xt)

        # Text conditioning via cross-attention (trainable)
        text_emb = self._get_clap_embeddings(text, device)
        x_cond, xt_cond = self.text_attn(x_enc, xt_enc, text_emb)

        # Decode with custom decoder (trainable)
        # Reverse skips for decoder
        saved_rev = saved[::-1]
        saved_t_rev = saved_t[::-1]
        lengths_rev = lengths[::-1]
        lengths_t_rev = lengths_t[::-1]

        # Frequency decoder
        x_dec = self.freq_decoder(x_cond, saved_rev, lengths_rev)
        x_dec = self.freq_out(x_dec)  # (B, 2, F, T)
        # Interpolate to match original spectrogram size
        x_dec = F.interpolate(x_dec, size=(Fq, T_spec), mode='bilinear', align_corners=False)

        # Apply as mask and invert spectrogram
        mask = torch.sigmoid(x_dec)  # (B, 2, F, T) in [0, 1]
        # Mask the complex STFT directly: use the true magnitude and unit-modulus
        # phase of z (the complex-as-channels tensor fed to the encoder is not a
        # magnitude, so it is not used here).
        mag_stereo = z.abs()  # (B, 2, F, T) magnitude
        masked_spec = mag_stereo * mask
        phase = z / (mag_stereo + 1e-8)  # Complex phase
        masked_z = masked_spec * phase  # Apply mask while preserving phase
        freq_wav = self.htdemucs._ispec(masked_z, original_length).to(device)

        # Time decoder
        xt_dec = self.time_decoder(xt_cond, saved_t_rev, lengths_t_rev)
        xt_dec = self.time_out(xt_dec)  # (B, 2, T)
        # Interpolate to original length
        if xt_dec.shape[-1] != original_length:
            xt_dec = F.interpolate(xt_dec, size=original_length, mode='linear', align_corners=False)
        # Denormalize time output
        xt_dec = xt_dec * stdt + meant

        # Combine frequency and time branches
        output = freq_wav + xt_dec
        return output


if __name__ == "__main__":
    from demucs import pretrained

    htdemucs = pretrained.get_model('htdemucs').models[0]
    clap = ClapModel.from_pretrained("laion/clap-htsat-unfused")
    tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
    model = AudioTextHTDemucs(htdemucs, clap, tokenizer)

    # Count params
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total:,}")
    print(f"Trainable params: {trainable:,}")

    # Test forward
    wav = torch.randn(2, 2, 44100 * 3)
    prompts = ["drums", "bass"]
    out = model(wav, prompts)
    print(f"Input: {wav.shape} -> Output: {out.shape}")
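
    # Hypothetical training-step sketch (illustrative only): optimize just the
    # trainable modules (cross-attention + custom decoders) with an L1 waveform
    # loss. The target below is random noise standing in for a real isolated
    # stem from a dataset.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=1e-4)
    target = torch.randn_like(out)  # placeholder for the ground-truth stem
    loss = F.l1_loss(out, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"Dummy training step, L1 loss: {loss.item():.4f}")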