import math
import random
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import soundfile as sf
import stempeg
import torch
from torch.utils.data import Dataset

# ============================================================================
# Data Loader
# ============================================================================

# Text prompt templates: each stem name maps to several equivalent prompts.
STEM_PROMPTS: Dict[str, List[str]] = {
    "drums": ["drums", "drum kit", "percussion", "the drums"],
    "bass": ["bass", "bass guitar", "the bass", "bass line"],
    "other": ["other instruments", "accompaniment", "instruments"],
    "vocals": ["vocals", "voice", "singing", "the vocals"],
}

# Reverse lookup: maps any prompt string back to its stem name.
PROMPT_TO_STEM: Dict[str, str] = {
    prompt: stem
    for stem, prompts in STEM_PROMPTS.items()
    for prompt in prompts
}

STEM_NAME_TO_INDEX = {"drums": 0, "bass": 1, "other": 2, "vocals": 3}


def get_random_prompt(stem_name: str) -> str:
    """Get a random text prompt for a given stem."""
    return random.choice(STEM_PROMPTS[stem_name])
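
# Note (illustrative): PROMPT_TO_STEM simply inverts STEM_PROMPTS, so any prompt
# can be mapped back to its stem, e.g. PROMPT_TO_STEM["bass line"] == "bass",
# and get_random_prompt("drums") always returns an entry of STEM_PROMPTS["drums"].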


class MusDBStemDataset(Dataset):
    """MUSDB18 stem dataset yielding (mixture, target stem, text prompt) items."""

    def __init__(
        self,
        root_dir: str,
        segment_samples: int,
        sample_rate: int = 44100,
        channels: int = 2,
        random_segments: bool = True,
        augment: bool = True,
    ):
        self.root_dir = Path(root_dir)
        self.segment_samples = segment_samples
        self.sample_rate = sample_rate
        self.channels = channels
        self.random_segments = random_segments
        self.augment = augment
        self.stem_names = ["drums", "bass", "other", "vocals"]

        self.files = list(self.root_dir.glob("*.stem.mp4"))
        if not self.files:
            raise ValueError(f"No .stem.mp4 files found in {root_dir}")

        # Compute the number of examples: one entry per (file, stem, segment)
        self.index_map = []  # (file_idx, stem_idx, segment_idx)
        for file_idx, file in enumerate(self.files):
            info = stempeg.Info(str(file))
            # Stream 0 (the mixture) is used as the length reference
            total_samples = info.duration(0) * info.sample_rate(0)
            num_segments = math.ceil(total_samples / segment_samples)
            for stem_idx in range(len(self.stem_names)):
                for seg in range(num_segments):
                    self.index_map.append((file_idx, stem_idx, seg))

        print(f"Found {len(self.files)} tracks, total dataset items: {len(self.index_map)}")

    def __len__(self) -> int:
        return len(self.index_map)

    def _load_stems(self, filepath: Path) -> np.ndarray:
        """Load all stems from a .stem.mp4 file."""
        stems, rate = stempeg.read_stems(str(filepath))
        # stems shape: (num_stems, samples, channels)
        # [mix, drums, bass, other, vocals]
        return stems

    def _extract_random_segment(self, stems: np.ndarray) -> np.ndarray:
        """Extract the same random segment from all stems."""
        total_samples = stems.shape[1]  # stems: (num_stems, samples, channels)
        if total_samples <= self.segment_samples:
            # Pad if too short
            pad_amount = self.segment_samples - total_samples
            stems = np.pad(stems, ((0, 0), (0, pad_amount), (0, 0)), mode="constant")
        else:
            # Random start position (same for all stems)
            if self.random_segments:
                start = random.randint(0, total_samples - self.segment_samples)
            else:
                start = 0
            stems = stems[:, start:start + self.segment_samples, :]
        return stems

    def _extract_segment(self, stems: np.ndarray, seg_idx: int) -> np.ndarray:
        """Extract the seg_idx-th fixed-length segment (deterministic)."""
        total_samples = stems.shape[1]
        if self.random_segments:
            # Fall back to the random segment extractor
            return self._extract_random_segment(stems)
        start = seg_idx * self.segment_samples
        end = start + self.segment_samples
        if end <= total_samples:
            return stems[:, start:end, :]
        else:
            # Last segment may need padding
            pad_amount = end - total_samples
            seg = stems[:, start:, :]
            seg = np.pad(seg, ((0, 0), (0, pad_amount), (0, 0)), mode="constant")
            return seg

    def _augment(self, mixture: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply data augmentation: random gain and random channel swap."""
        # Random gain, applied identically to mixture and target
        if random.random() < 0.5:
            gain = random.uniform(0.7, 1.3)
            mixture = mixture * gain
            target = target * gain
        # Random L/R channel swap (stereo only); arrays are (T, C) here
        if random.random() < 0.3 and mixture.shape[-1] == 2:
            mixture = mixture[:, ::-1].copy()
            target = target[:, ::-1].copy()
        return mixture, target

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor | str | int]:
        file_idx, stem_idx, seg_idx = self.index_map[idx]
        filepath = self.files[file_idx]

        # Load the full track, then cut one segment (random when
        # random_segments=True, otherwise the deterministic seg_idx slice)
        stems = self._load_stems(filepath)
        stems = self._extract_segment(stems, seg_idx)

        mixture = stems[0]            # (T, C)
        target = stems[stem_idx + 1]  # (T, C); stream 0 is the mixture

        if self.augment:
            mixture, target = self._augment(mixture, target)

        # (T, C) -> (C, T)
        mixture = torch.from_numpy(mixture.T).float()
        target = torch.from_numpy(target.T).float()

        # Ensure stereo
        if mixture.shape[0] == 1:
            mixture = mixture.repeat(2, 1)
            target = target.repeat(2, 1)

        prompt = get_random_prompt(self.stem_names[stem_idx])

        return {
            "mixture": mixture,
            "target": target,
            "prompt": prompt,
            "stem_name": self.stem_names[stem_idx],
            "file_idx": file_idx,
            "segment_idx": seg_idx,
        }


def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor | List[str]]:
    """Custom collate function: stack audio tensors, keep string fields as lists."""
    return {
        "mixture": torch.stack([item["mixture"] for item in batch]),
        "target": torch.stack([item["target"] for item in batch]),
        "prompt": [item["prompt"] for item in batch],
        "stem_name": [item["stem_name"] for item in batch],
    }
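

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only): wiring the dataset and collate_fn into a
# torch DataLoader. The root path and hyperparameters below are assumptions
# (a local MUSDB18 download with *.stem.mp4 files), not part of the loader itself.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MusDBStemDataset(
        root_dir="musdb18/train",      # hypothetical path to .stem.mp4 files
        segment_samples=44100 * 6,     # 6-second segments at 44.1 kHz
        random_segments=True,
        augment=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
    )

    batch = next(iter(loader))
    # mixture/target: (batch, channels, segment_samples); prompt: list of strings
    print(batch["mixture"].shape, batch["target"].shape, batch["prompt"])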