# NOTE: header left over from a web file-viewer scrape, kept as a comment:
#   Spaces: Sleeping | File size: 6,389 Bytes | revision 7417a6a
# The viewer's line-number gutter (1-179) is not part of the module.
import random
from pathlib import Path
from typing import Dict, List, Tuple
import torch
from torch.utils.data import Dataset
import stempeg
import soundfile as sf
import math
import numpy as np
# ============================================================================
# Data Loader
# ============================================================================
def get_random_prompt(stem_name: str) -> str:
    """Pick one text prompt at random for the given stem.

    Args:
        stem_name: Key into the module-level ``STEM_PROMPTS`` table.

    Returns:
        One of the prompt strings registered for ``stem_name``.

    Raises:
        KeyError: If ``stem_name`` has no entry in ``STEM_PROMPTS``.
    """
    candidates = STEM_PROMPTS[stem_name]
    return random.choice(candidates)
# Text prompt templates: every stem maps to the phrasings that may describe it.
STEM_PROMPTS: Dict[str, List[str]] = dict(
    drums=["drums", "drum kit", "percussion", "the drums"],
    bass=["bass", "bass guitar", "the bass", "bass line"],
    other=["other instruments", "accompaniment", "instruments"],
    vocals=["vocals", "voice", "singing", "the vocals"],
)

# Reverse lookup: prompt text -> the stem it describes.
PROMPT_TO_STEM: Dict[str, str] = dict(
    (text, name)
    for name, texts in STEM_PROMPTS.items()
    for text in texts
)

# Position of each stem among the four separable sources.
STEM_NAME_TO_INDEX = {
    name: position
    for position, name in enumerate(("drums", "bass", "other", "vocals"))
}
class MusDBStemDataset(Dataset):
def __init__(
self,
root_dir: str,
segment_samples: int,
sample_rate: int = 44100,
channels: int = 2,
random_segments: bool = True,
augment: bool = True,
):
self.root_dir = Path(root_dir)
self.segment_samples = segment_samples
self.sample_rate = sample_rate
self.channels = channels
self.random_segments = random_segments
self.augment = augment
self.stem_names = ["drums", "bass", "other", "vocals"]
self.files = list(self.root_dir.glob("*.stem.mp4"))
if not self.files:
raise ValueError(f"No .stem.mp4 files found in {root_dir}")
# Compute number of examples
self.index_map = [] # (file_idx, stem_idx, segment_idx)
#self.sample_lengths = [0] * len(self.files) # total samples per file
for file_idx, file in enumerate(self.files):
info = stempeg.Info(str(file))
total_samples = info.duration(0) * info.sample_rate(0) # 0 - using mixture stream as reference
#self.sample_lengths[file_idx] = int(total_samples)
num_segments = math.ceil(total_samples / segment_samples)
# Build index map: for each stem, each segment
for stem_idx in range(len(self.stem_names)):
for seg in range(num_segments):
self.index_map.append((file_idx, stem_idx, seg))
print(f"Found {len(self.files)} tracks, total dataset items: {len(self.index_map)}")
def __len__(self) -> int:
return len(self.index_map)
def _load_stems(self, filepath: Path) -> np.ndarray:
"""Load all stems from a .stem.mp4 file."""
stems, rate = stempeg.read_stems(str(filepath))
# stems shape: (num_stems, samples, channels)
# [mix, drums, bass, other, vocals]
return stems
def _extract_random_segment(self, stems: np.ndarray) -> np.ndarray:
"""Extract the same random segment from all stems."""
total_samples = stems.shape[1] # stems: (num_stems, samples, channels)
if total_samples <= self.segment_samples:
# Pad if too short
pad_amount = self.segment_samples - total_samples
stems = np.pad(stems, ((0, 0), (0, pad_amount), (0, 0)), mode='constant')
else:
# Random start position (same for all stems)
if self.random_segments:
start = random.randint(0, total_samples - self.segment_samples)
else:
start = 0
stems = stems[:, start:start + self.segment_samples, :]
return stems
def _extract_segment(self, stems: np.ndarray, seg_idx: int) -> np.ndarray:
total_samples = stems.shape[1]
if self.random_segments:
# fallback to random segment extractor
return self._extract_random_segment(stems)
start = seg_idx * self.segment_samples
end = start + self.segment_samples
if end <= total_samples:
return stems[:, start:end, :]
else:
# Last segment may need padding
pad_amount = end - total_samples
seg = stems[:, start:, :]
seg = np.pad(seg, ((0, 0),(0, pad_amount), (0, 0)), mode="constant")
return seg
def _augment(self, mixture: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Apply data augmentation."""
if random.random() < 0.5:
gain = random.uniform(0.7, 1.3)
mixture = mixture * gain
target = target * gain
if random.random() < 0.3 and mixture.shape[-1] == 2:
mixture = mixture[:, ::-1].copy()
target = target[:, ::-1].copy()
return mixture, target
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor | str]:
file_idx, stem_idx, seg_idx = self.index_map[idx]
filepath = self.files[file_idx]
stems = self._load_stems(filepath)
# deterministic segment selection
stems = self._extract_segment(stems, seg_idx)
mixture = stems[0] # (T, C)
target = stems[stem_idx+1] # (T, C)
if self.augment:
mixture, target = self._augment(mixture, target)
# -> (C, T)
mixture = torch.from_numpy(mixture.T).float()
target = torch.from_numpy(target.T).float()
# ensure stereo
if mixture.shape[0] == 1:
mixture = mixture.repeat(2, 1)
target = target.repeat(2, 1)
prompt = get_random_prompt(self.stem_names[stem_idx])
return {
"mixture": mixture,
"target": target,
"prompt": prompt,
"stem_name": self.stem_names[stem_idx],
"file_idx": file_idx,
"segment_idx": seg_idx,
}
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor | List[str]]:
"""Custom collate function."""
return {
"mixture": torch.stack([item["mixture"] for item in batch]),
"target": torch.stack([item["target"] for item in batch]),
"prompt": [item["prompt"] for item in batch],
"stem_name": [item["stem_name"] for item in batch],
} |