# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of music tracks with rich metadata."""
from dataclasses import dataclass, field, fields, replace
import gzip
import json
import logging
from pathlib import Path
import random
import typing as tp

import numpy as np
import torch

from .btc_chords import Chords
from .info_audio_dataset import (
    InfoAudioDataset,
    AudioInfo,
    get_keyword_list,
    get_keyword,
    get_string
)
from ..modules.conditioners import (
    ConditioningAttributes,
    JointEmbedCondition,
    WavCondition,
    ChordCondition,
    BeatCondition
)
from ..utils.utils import warn_once

logger = logging.getLogger(__name__)

CHORDS = Chords()


@dataclass
class MusicInfo(AudioInfo):
    """Segment info augmented with music metadata."""
    # music-specific metadata
    title: tp.Optional[str] = None
    artist: tp.Optional[str] = None  # anonymized artist id, used to ensure no overlap between splits
    key: tp.Optional[str] = None
    bpm: tp.Optional[float] = None
    genre: tp.Optional[str] = None
    moods: tp.Optional[list] = None
    keywords: tp.Optional[list] = None
    description: tp.Optional[str] = None
    name: tp.Optional[str] = None
    instrument: tp.Optional[str] = None
    chord: tp.Optional[ChordCondition] = None
    beat: tp.Optional[BeatCondition] = None
    # original wav accompanying the metadata
    self_wav: tp.Optional[WavCondition] = None
    # dict mapping attribute names to tuple of wav, text and metadata
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    @property
    def has_music_meta(self) -> bool:
        return self.name is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        out = ConditioningAttributes()
        for _field in fields(self):
            key, value = _field.name, getattr(self, _field.name)
            if key == 'self_wav':
                out.wav[key] = value
            elif key == 'chord':
                out.chord[key] = value
            elif key == 'beat':
                out.beat[key] = value
            elif key == 'joint_embed':
                for embed_attribute, embed_cond in value.items():
                    out.joint_embed[embed_attribute] = embed_cond
            else:
                if isinstance(value, list):
                    value = ' '.join(value)
                out.text[key] = value
        return out
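    # A minimal sketch of the routing above, assuming `info` is an already
    # constructed MusicInfo (hypothetical values):
    #   attrs = info.to_condition_attributes()
    #   attrs.wav['self_wav']   # WavCondition (or None)
    #   attrs.chord['chord']    # ChordCondition (or None)
    #   attrs.text['genre']     # e.g. 'rock'; list fields like moods are space-joined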

    @staticmethod
    def attribute_getter(attribute):
        if attribute == 'bpm':
            preprocess_func = get_bpm
        elif attribute == 'key':
            preprocess_func = get_musical_key
        elif attribute in ['moods', 'keywords']:
            preprocess_func = get_keyword_list
        elif attribute in ['genre', 'name', 'instrument']:
            preprocess_func = get_keyword
        elif attribute in ['title', 'artist', 'description']:
            preprocess_func = get_string
        else:
            preprocess_func = None
        return preprocess_func

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        _dictionary: tp.Dict[str, tp.Any] = {}
        # allow a subset of attributes to not be loaded from the dictionary
        # these attributes may be populated later
        post_init_attributes = ['self_wav', 'chord', 'beat', 'joint_embed']
        optional_fields = ['keywords']
        for _field in fields(cls):
            if _field.name in post_init_attributes:
                continue
            elif _field.name not in dictionary:
                if fields_required and _field.name not in optional_fields:
                    raise KeyError(f"Unexpected missing key: {_field.name}")
            else:
                preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
                value = dictionary[_field.name]
                if preprocess_func:
                    value = preprocess_func(value)
                _dictionary[_field.name] = value
        return cls(**_dictionary)
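    # A minimal sketch of the expected input, assuming the dict also carries the
    # inherited AudioInfo fields (meta, seek_time, n_frames, ...) merged in by the
    # dataset below; the music keys shown here are hypothetical:
    #   music_data = {**info_data, 'title': 'Song', 'bpm': '120', 'key': 'c major'}
    #   info = MusicInfo.from_dict(music_data, fields_required=False)
    #   info.bpm  -> 120.0  (get_bpm coerces the string to float)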


def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
                                   drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
    """Augment MusicInfo description with additional metadata fields and potential dropout.
    Additional textual attributes are merged into the description with probability `merge_text_p`,
    and the original textual description is dropped from the augmented description with
    probability `drop_desc_p`.

    Args:
        music_info (MusicInfo): The music metadata to augment.
        merge_text_p (float): Probability of merging additional metadata into the description.
            If the provided value is 0, no merging is performed.
        drop_desc_p (float): Probability of dropping the original description on text merge.
            If the provided value is 0, no dropout is performed.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
    Returns:
        MusicInfo: The MusicInfo with augmented textual description.
    """
    def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
        valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
        valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
        # note: despite its name, drop_other_p acts here as the probability of *keeping* a field
        keep_field = random.uniform(0, 1) < drop_other_p
        return valid_field_name and valid_field_value and keep_field

    def process_value(v: tp.Any) -> str:
        if isinstance(v, (int, float, str)):
            return str(v)
        if isinstance(v, list):
            return ", ".join(v)
        else:
            raise ValueError(f"Unknown type for text value! ({type(v), v})")

    description = music_info.description
    metadata_text = ""
    if random.uniform(0, 1) < merge_text_p:
        meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
                      for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
        random.shuffle(meta_pairs)
        metadata_text = ". ".join(meta_pairs)
        description = description if not random.uniform(0, 1) < drop_desc_p else None
        logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")

    if description is None:
        description = metadata_text if len(metadata_text) > 1 else None
    else:
        description = ". ".join([description.rstrip('.'), metadata_text])
    description = description.strip() if description else None

    music_info = replace(music_info)
    music_info.description = description
    return music_info
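# For intuition, a hypothetical pass where the merge triggers and the original
# description is kept could turn description="A catchy rock tune" plus kept fields
# bpm=120 and genre="rock" into "A catchy rock tune. bpm: 120. genre: rock"
# (pair order is shuffled).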


class Paraphraser:
    def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
        self.paraphrase_p = paraphrase_p
        open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
        with open_fn(paraphrase_source, 'rb') as f:  # type: ignore
            self.paraphrase_source = json.loads(f.read())
        logger.info(f"loaded paraphrasing source from: {paraphrase_source}")

    def sample_paraphrase(self, audio_path: str, description: str):
        if random.random() >= self.paraphrase_p:
            return description
        # json.loads produces string keys, so stringify the path before the lookup
        info_path = str(Path(audio_path).with_suffix('.json'))
        if info_path not in self.paraphrase_source:
            warn_once(logger, f"{info_path} not in paraphrase source!")
            return description
        new_desc = random.choice(self.paraphrase_source[info_path])
        logger.debug(f"{description} -> {new_desc}")
        return new_desc
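    # Usage sketch, assuming a paraphrase file of the form
    # {"some/track_path.json": ["paraphrase 1", "paraphrase 2"]}:
    #   paraphraser = Paraphraser('paraphrases.json.gz', paraphrase_p=0.5)
    #   desc = paraphraser.sample_paraphrase('some/track_path.wav', 'original description')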


class MusicDataset(InfoAudioDataset):
    """Music dataset is an AudioDataset with music-related metadata.

    Args:
        info_fields_required (bool): Whether to enforce having required fields.
        merge_text_p (float): Probability of merging additional metadata into the description.
        drop_desc_p (float): Probability of dropping the original description on text merge.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
        joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
        paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
            paraphrases for the description. The json should be a dict whose keys are the
            original info paths (e.g. track_path.json) and whose values are lists of possible
            paraphrases.
        paraphrase_p (float): Probability of taking a paraphrase.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(self, *args, info_fields_required: bool = True,
                 merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
                 joint_embed_attributes: tp.List[str] = [],
                 paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
                 **kwargs):
        kwargs['return_info'] = True  # we require the info for each song of the dataset
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.merge_text_p = merge_text_p
        self.drop_desc_p = drop_desc_p
        self.drop_other_p = drop_other_p
        self.joint_embed_attributes = joint_embed_attributes
        self.paraphraser = None
        # features are computed at sr / downsample_rate = 32000 / 640 = 50 frames per second
        self.downsample_rate = 640
        self.sr = 32000
        if paraphrase_source is not None:
            self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
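    # __getitem__ below assumes a fixed sibling-file layout next to each audio file
    # (naming taken from the path-rewriting logic in the method; paths are hypothetical):
    #   some/track/no_vocal.wav   # audio used for training
    #   some/track/tags.json      # music metadata (title, bpm, genre, ...)
    #   some/track/chord.lab      # chord annotations: "<start> <end> <chord>" per line
    #   some/track/beats.npy      # rows of [time, beat_position]; position 1 marks a downbeat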

    def __getitem__(self, index):
        wav, info = super().__getitem__(index)  # wav_seg and seg_info
        info_data = info.to_dict()

        # unpack info
        target_sr = self.sr
        n_frames_feat = int(info.n_frames // self.downsample_rate)

        music_info_path = str(info.meta.path).replace('no_vocal.wav', 'tags.json')
        chord_path = str(info.meta.path).replace('no_vocal.wav', 'chord.lab')
        beats_path = str(info.meta.path).replace('no_vocal.wav', 'beats.npy')

        # all three metadata files are read below, so fail early if any is missing
        if any([
            not Path(music_info_path).exists(),
            not Path(beats_path).exists(),
            not Path(chord_path).exists(),
        ]):
            raise FileNotFoundError(f"missing metadata files for {info.meta.path}")

        ### music info
        with open(music_info_path, 'r') as json_file:
            music_data = json.load(json_file)
            music_data.update(info_data)
            music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
        if self.paraphraser is not None:
            music_info.description = self.paraphraser.sample_paraphrase(
                music_info.meta.path, music_info.description)
        if self.merge_text_p:
            music_info = augment_music_info_description(
                music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)

        ### load features to tensors ###
        feat_hz = target_sr / self.downsample_rate

        ## beat & bar: 2 x T
        feat_beats = np.zeros((2, n_frames_feat))
        beats_np = np.load(beats_path)
        beat_time = beats_np[:, 0]
        bar_time = beats_np[np.where(beats_np[:, 1] == 1)[0], 0]  # downbeats
        beat_frame = [
            int((t - info.seek_time) * feat_hz) for t in beat_time
            if (t >= info.seek_time and t < info.seek_time + self.segment_duration)]
        bar_frame = [
            int((t - info.seek_time) * feat_hz) for t in bar_time
            if (t >= info.seek_time and t < info.seek_time + self.segment_duration)]
        feat_beats[0, beat_frame] = 1
        feat_beats[1, bar_frame] = 1
        kernel = np.array([0.05, 0.1, 0.3, 0.9, 0.3, 0.1, 0.05])
        feat_beats[0] = np.convolve(feat_beats[0], kernel, 'same')  # smooth the binary beat track with a soft kernel
        beat_events = feat_beats[0] + feat_beats[1]
        beat_events = torch.tensor(beat_events).unsqueeze(0)  # [T] -> [1, T]
        music_info.beat = BeatCondition(beat=beat_events[None], length=torch.tensor([n_frames_feat]),
                                        bpm=[music_data["bpm"]], path=[music_info_path],
                                        seek_frame=[info.seek_time * target_sr // self.downsample_rate])

        ## chord: 12 x T
        feat_chord = np.zeros((12, n_frames_feat))  # 12-dim chroma per frame
        with open(chord_path, 'r') as f:
            for line in f.readlines():
                splits = line.split()
                if len(splits) == 3:
                    st_sec, ed_sec, ctag = splits
                    st_sec = float(st_sec) - info.seek_time
                    ed_sec = float(ed_sec) - info.seek_time
                    # clamp to the segment so partially overlapping chords do not index out of range
                    st_frame = max(int(st_sec * feat_hz), 0)
                    ed_frame = min(int(ed_sec * feat_hz), n_frames_feat)
                    if ed_frame <= st_frame:
                        continue
                    # 12-dim chroma: interval template rolled to the chord root
                    mhot = CHORDS.chord(ctag)
                    final_vec = np.roll(mhot[2], mhot[0])
                    final_vec = final_vec[..., None]
                    feat_chord[:, st_frame:ed_frame] = final_vec
        feat_chord = torch.from_numpy(feat_chord)
        music_info.chord = ChordCondition(
            chord=feat_chord[None], length=torch.tensor([n_frames_feat]),
            bpm=[music_data["bpm"]], path=[chord_path],
            seek_frame=[info.seek_time * self.sr // self.downsample_rate])

        music_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        for att in self.joint_embed_attributes:
            att_value = getattr(music_info, att)
            joint_embed_cond = JointEmbedCondition(
                wav[None], [att_value], torch.tensor([info.n_frames]),
                sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
            music_info.joint_embed[att] = joint_embed_cond

        return wav, music_info


def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess key keywords, discarding them if multiple keys are defined."""
    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
        return None
    elif ',' in value:
        # for now, discard the key when multiple keys are defined, separated by commas
        return None
    else:
        return value.strip().lower()


def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
    """Preprocess to a float."""
    if value is None:
        return None
    try:
        return float(value)
    except ValueError:
        return None
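

# --- Minimal usage sketch (assumptions, not part of the module API) ---
# `from_meta` comes from the AudioDataset base class; the manifest folder and
# keyword values below are hypothetical, and the sibling files described above
# (tags.json, chord.lab, beats.npy) must exist next to each no_vocal.wav.
#
#   dataset = MusicDataset.from_meta(
#       'egs/music',                 # folder with data.jsonl[.gz] manifests
#       segment_duration=30, sample_rate=32000, channels=1,
#       merge_text_p=0.25, drop_desc_p=0.5, drop_other_p=0.5)
#   wav, music_info = dataset[0]
#   # wav: [C, T_wav]; music_info.beat.beat: [1, 1, T]; music_info.chord.chord: [1, 12, T]
#   attributes = music_info.to_condition_attributes()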