diff --git a/AIMeiSheng/demucs/__init__.py b/AIMeiSheng/demucs/__init__.py new file mode 100644 index 0000000..e02c0ad --- /dev/null +++ b/AIMeiSheng/demucs/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +__version__ = "4.1.0a2" diff --git a/AIMeiSheng/demucs/__main__.py b/AIMeiSheng/demucs/__main__.py new file mode 100644 index 0000000..da0a541 --- /dev/null +++ b/AIMeiSheng/demucs/__main__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .separate import main + +if __name__ == '__main__': + main() diff --git a/AIMeiSheng/demucs/api.py b/AIMeiSheng/demucs/api.py new file mode 100644 index 0000000..20079a6 --- /dev/null +++ b/AIMeiSheng/demucs/api.py @@ -0,0 +1,392 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""API methods for demucs + +Classes +------- +`demucs.api.Separator`: The base separator class + +Functions +--------- +`demucs.api.save_audio`: Save an audio +`demucs.api.list_models`: Get models list + +Examples +-------- +See the end of this module (if __name__ == "__main__") +""" + +import subprocess + +import torch as th +import torchaudio as ta + +from dora.log import fatal +from pathlib import Path +from typing import Optional, Callable, Dict, Tuple, Union + +from .apply import apply_model, _replace_dict +from .audio import AudioFile, convert_audio, save_audio +from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT +from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo + + +class LoadAudioError(Exception): + pass + + +class LoadModelError(Exception): + pass + + +class _NotProvided: + pass + + +NotProvided = _NotProvided() + + +class Separator: + def __init__( + self, + model: str = "htdemucs", + repo: Optional[Path] = None, + device: str = "cuda" if th.cuda.is_available() else "cpu", + shifts: int = 1, + overlap: float = 0.25, + split: bool = True, + segment: Optional[int] = None, + jobs: int = 0, + progress: bool = False, + callback: Optional[Callable[[dict], None]] = None, + callback_arg: Optional[dict] = None, + ): + """ + `class Separator` + ================= + + Parameters + ---------- + model: Pretrained model name or signature. Default is htdemucs. + repo: Folder containing all pre-trained models for use. + segment: Length (in seconds) of each segment (only available if `split` is `True`). If \ + not specified, will use the command line option. + shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \ + apply the oppositve shift to the output. This is repeated `shifts` time and all \ + predictions are averaged. This effectively makes the model time equivariant and \ + improves SDR by up to 0.2 points. If not specified, will use the command line option. + split: If True, the input will be broken down into small chunks (length set by `segment`) \ + and predictions will be performed individually on each and concatenated. Useful for \ + model with large memory footprint like Tasnet. If not specified, will use the command \ + line option. + overlap: The overlap between the splits. If not specified, will use the command line \ + option. + device (torch.device, str, or None): If provided, device on which to execute the \ + computation, otherwise `wav.device` is assumed. When `device` is different from \ + `wav.device`, only local computations will be on `device`, while the entire tracks \ + will be stored on `wav.device`. If not specified, will use the command line option. + jobs: Number of jobs. This can increase memory usage but will be much faster when \ + multiple cores are available. If not specified, will use the command line option. + callback: A function will be called when the separation of a chunk starts or finished. \ + The argument passed to the function will be a dict. For more information, please see \ + the Callback section. + callback_arg: A dict containing private parameters to be passed to callback function. For \ + more information, please see the Callback section. + progress: If true, show a progress bar. + + Callback + -------- + The function will be called with only one positional parameter whose type is `dict`. The + `callback_arg` will be combined with information of current separation progress. The + progress information will override the values in `callback_arg` if same key has been used. + To abort the separation, raise `KeyboardInterrupt`. + + Progress information contains several keys (These keys will always exist): + - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0. + - `shift_idx`: The index of shifts. Starts from 0. + - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't + mean that it is at the 441000 second of the audio, but the "frame" of the tensor. + - `state`: Could be `"start"` or `"end"`. + - `audio_length`: Length of the audio (in "frame" of the tensor). + - `models`: Count of submodels in the model. + """ + self._name = model + self._repo = repo + self._load_model() + self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split, + segment=segment, jobs=jobs, progress=progress, callback=callback, + callback_arg=callback_arg) + + def update_parameter( + self, + device: Union[str, _NotProvided] = NotProvided, + shifts: Union[int, _NotProvided] = NotProvided, + overlap: Union[float, _NotProvided] = NotProvided, + split: Union[bool, _NotProvided] = NotProvided, + segment: Optional[Union[int, _NotProvided]] = NotProvided, + jobs: Union[int, _NotProvided] = NotProvided, + progress: Union[bool, _NotProvided] = NotProvided, + callback: Optional[ + Union[Callable[[dict], None], _NotProvided] + ] = NotProvided, + callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided, + ): + """ + Update the parameters of separation. + + Parameters + ---------- + segment: Length (in seconds) of each segment (only available if `split` is `True`). If \ + not specified, will use the command line option. + shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \ + apply the oppositve shift to the output. This is repeated `shifts` time and all \ + predictions are averaged. This effectively makes the model time equivariant and \ + improves SDR by up to 0.2 points. If not specified, will use the command line option. + split: If True, the input will be broken down into small chunks (length set by `segment`) \ + and predictions will be performed individually on each and concatenated. Useful for \ + model with large memory footprint like Tasnet. If not specified, will use the command \ + line option. + overlap: The overlap between the splits. If not specified, will use the command line \ + option. + device (torch.device, str, or None): If provided, device on which to execute the \ + computation, otherwise `wav.device` is assumed. When `device` is different from \ + `wav.device`, only local computations will be on `device`, while the entire tracks \ + will be stored on `wav.device`. If not specified, will use the command line option. + jobs: Number of jobs. This can increase memory usage but will be much faster when \ + multiple cores are available. If not specified, will use the command line option. + callback: A function will be called when the separation of a chunk starts or finished. \ + The argument passed to the function will be a dict. For more information, please see \ + the Callback section. + callback_arg: A dict containing private parameters to be passed to callback function. For \ + more information, please see the Callback section. + progress: If true, show a progress bar. + + Callback + -------- + The function will be called with only one positional parameter whose type is `dict`. The + `callback_arg` will be combined with information of current separation progress. The + progress information will override the values in `callback_arg` if same key has been used. + To abort the separation, raise `KeyboardInterrupt`. + + Progress information contains several keys (These keys will always exist): + - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0. + - `shift_idx`: The index of shifts. Starts from 0. + - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't + mean that it is at the 441000 second of the audio, but the "frame" of the tensor. + - `state`: Could be `"start"` or `"end"`. + - `audio_length`: Length of the audio (in "frame" of the tensor). + - `models`: Count of submodels in the model. + """ + if not isinstance(device, _NotProvided): + self._device = device + if not isinstance(shifts, _NotProvided): + self._shifts = shifts + if not isinstance(overlap, _NotProvided): + self._overlap = overlap + if not isinstance(split, _NotProvided): + self._split = split + if not isinstance(segment, _NotProvided): + self._segment = segment + if not isinstance(jobs, _NotProvided): + self._jobs = jobs + if not isinstance(progress, _NotProvided): + self._progress = progress + if not isinstance(callback, _NotProvided): + self._callback = callback + if not isinstance(callback_arg, _NotProvided): + self._callback_arg = callback_arg + + def _load_model(self): + self._model = get_model(name=self._name, repo=self._repo) + if self._model is None: + raise LoadModelError("Failed to load model") + self._audio_channels = self._model.audio_channels + self._samplerate = self._model.samplerate + + def _load_audio(self, track: Path): + errors = {} + wav = None + + try: + wav = AudioFile(track).read(streams=0, samplerate=self._samplerate, + channels=self._audio_channels) + except FileNotFoundError: + errors["ffmpeg"] = "FFmpeg is not installed." + except subprocess.CalledProcessError: + errors["ffmpeg"] = "FFmpeg could not read the file." + + if wav is None: + try: + wav, sr = ta.load(str(track)) + except RuntimeError as err: + errors["torchaudio"] = err.args[0] + else: + wav = convert_audio(wav, sr, self._samplerate, self._audio_channels) + + if wav is None: + raise LoadAudioError( + "\n".join( + "When trying to load using {}, got the following error: {}".format( + backend, error + ) + for backend, error in errors.items() + ) + ) + return wav + + def separate_tensor( + self, wav: th.Tensor, sr: Optional[int] = None + ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]: + """ + Separate a loaded tensor. + + Parameters + ---------- + wav: Waveform of the audio. Should have 2 dimensions, the first is each audio channel, \ + while the second is the waveform of each channel. Type should be float32. \ + e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels. + sr: Sample rate of the original audio, the wave will be resampled if it doesn't match the \ + model. + + Returns + ------- + A tuple, whose first element is the original wave and second element is a dict, whose keys + are the name of stems and values are separated waves. The original wave will have already + been resampled. + + Notes + ----- + Use this function with cautiousness. This function does not provide data verifying. + """ + if sr is not None and sr != self.samplerate: + wav = convert_audio(wav, sr, self._samplerate, self._audio_channels) + ref = wav.mean(0) + wav -= ref.mean() + wav /= ref.std() + 1e-8 + out = apply_model( + self._model, + wav[None], + segment=self._segment, + shifts=self._shifts, + split=self._split, + overlap=self._overlap, + device=self._device, + num_workers=self._jobs, + callback=self._callback, + callback_arg=_replace_dict( + self._callback_arg, ("audio_length", wav.shape[1]) + ), + progress=self._progress, + ) + if out is None: + raise KeyboardInterrupt + out *= ref.std() + 1e-8 + out += ref.mean() + wav *= ref.std() + 1e-8 + wav += ref.mean() + return (wav, dict(zip(self._model.sources, out[0]))) + + def separate_audio_file(self, file: Path): + """ + Separate an audio file. The method will automatically read the file. + + Parameters + ---------- + wav: Path of the file to be separated. + + Returns + ------- + A tuple, whose first element is the original wave and second element is a dict, whose keys + are the name of stems and values are separated waves. The original wave will have already + been resampled. + """ + return self.separate_tensor(self._load_audio(file), self.samplerate) + + @property + def samplerate(self): + return self._samplerate + + @property + def audio_channels(self): + return self._audio_channels + + @property + def model(self): + return self._model + + +def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]: + """ + List the available models. Please remember that not all the returned models can be + successfully loaded. + + Parameters + ---------- + repo: The repo whose models are to be listed. + + Returns + ------- + A dict with two keys ("single" for single models and "bag" for bag of models). The values are + lists whose components are strs. + """ + model_repo: ModelOnlyRepo + if repo is None: + models = _parse_remote_files(REMOTE_ROOT / 'files.txt') + model_repo = RemoteRepo(models) + bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) + else: + if not repo.is_dir(): + fatal(f"{repo} must exist and be a directory.") + model_repo = LocalRepo(repo) + bag_repo = BagOnlyRepo(repo, model_repo) + return {"single": model_repo.list_model(), "bag": bag_repo.list_model()} + + +if __name__ == "__main__": + # Test API functions + # two-stem not supported + + from .separate import get_parser + + args = get_parser().parse_args() + separator = Separator( + model=args.name, + repo=args.repo, + device=args.device, + shifts=args.shifts, + overlap=args.overlap, + split=args.split, + segment=args.segment, + jobs=args.jobs, + callback=print + ) + out = args.out / args.name + out.mkdir(parents=True, exist_ok=True) + for file in args.tracks: + separated = separator.separate_audio_file(file)[1] + if args.mp3: + ext = "mp3" + elif args.flac: + ext = "flac" + else: + ext = "wav" + kwargs = { + "samplerate": separator.samplerate, + "bitrate": args.mp3_bitrate, + "clip": args.clip_mode, + "as_float": args.float32, + "bits_per_sample": 24 if args.int24 else 16, + } + for stem, source in separated.items(): + stem = out / args.filename.format( + track=Path(file).name.rsplit(".", 1)[0], + trackext=Path(file).name.rsplit(".", 1)[-1], + stem=stem, + ext=ext, + ) + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(source, str(stem), **kwargs) diff --git a/AIMeiSheng/demucs/apply.py b/AIMeiSheng/demucs/apply.py new file mode 100644 index 0000000..1540f3d --- /dev/null +++ b/AIMeiSheng/demucs/apply.py @@ -0,0 +1,322 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Code to apply a model to a mix. It will handle chunking with overlaps and +inteprolation between chunks, as well as the "shift trick". +""" +from concurrent.futures import ThreadPoolExecutor +import copy +import random +from threading import Lock +import typing as tp + +import torch as th +from torch import nn +from torch.nn import functional as F +import tqdm + +from .demucs import Demucs +from .hdemucs import HDemucs +from .htdemucs import HTDemucs +from .utils import center_trim, DummyPoolExecutor + +Model = tp.Union[Demucs, HDemucs, HTDemucs] + + +class BagOfModels(nn.Module): + def __init__(self, models: tp.List[Model], + weights: tp.Optional[tp.List[tp.List[float]]] = None, + segment: tp.Optional[float] = None): + """ + Represents a bag of models with specific weights. + You should call `apply_model` rather than calling directly the forward here for + optimal performance. + + Args: + models (list[nn.Module]): list of Demucs/HDemucs models. + weights (list[list[float]]): list of weights. If None, assumed to + be all ones, otherwise it should be a list of N list (N number of models), + each containing S floats (S number of sources). + segment (None or float): overrides the `segment` attribute of each model + (this is performed inplace, be careful is you reuse the models passed). + """ + super().__init__() + assert len(models) > 0 + first = models[0] + for other in models: + assert other.sources == first.sources + assert other.samplerate == first.samplerate + assert other.audio_channels == first.audio_channels + if segment is not None: + if not isinstance(other, HTDemucs) and segment > other.segment: + other.segment = segment + + self.audio_channels = first.audio_channels + self.samplerate = first.samplerate + self.sources = first.sources + self.models = nn.ModuleList(models) + + if weights is None: + weights = [[1. for _ in first.sources] for _ in models] + else: + assert len(weights) == len(models) + for weight in weights: + assert len(weight) == len(first.sources) + self.weights = weights + + @property + def max_allowed_segment(self) -> float: + max_allowed_segment = float('inf') + for model in self.models: + if isinstance(model, HTDemucs): + max_allowed_segment = min(max_allowed_segment, float(model.segment)) + return max_allowed_segment + + def forward(self, x): + raise NotImplementedError("Call `apply_model` on this.") + + +class TensorChunk: + def __init__(self, tensor, offset=0, length=None): + total_length = tensor.shape[-1] + assert offset >= 0 + assert offset < total_length + + if length is None: + length = total_length - offset + else: + length = min(total_length - offset, length) + + if isinstance(tensor, TensorChunk): + self.tensor = tensor.tensor + self.offset = offset + tensor.offset + else: + self.tensor = tensor + self.offset = offset + self.length = length + self.device = tensor.device + + @property + def shape(self): + shape = list(self.tensor.shape) + shape[-1] = self.length + return shape + + def padded(self, target_length): + delta = target_length - self.length + total_length = self.tensor.shape[-1] + assert delta >= 0 + + start = self.offset - delta // 2 + end = start + target_length + + correct_start = max(0, start) + correct_end = min(total_length, end) + + pad_left = correct_start - start + pad_right = end - correct_end + + out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) + assert out.shape[-1] == target_length + return out + + +def tensor_chunk(tensor_or_chunk): + if isinstance(tensor_or_chunk, TensorChunk): + return tensor_or_chunk + else: + assert isinstance(tensor_or_chunk, th.Tensor) + return TensorChunk(tensor_or_chunk) + + +def _replace_dict(_dict: tp.Optional[dict], *subs: tp.Tuple[tp.Hashable, tp.Any]) -> dict: + if _dict is None: + _dict = {} + else: + _dict = copy.copy(_dict) + for key, value in subs: + _dict[key] = value + return _dict + + +def apply_model(model: tp.Union[BagOfModels, Model], + mix: tp.Union[th.Tensor, TensorChunk], + shifts: int = 1, split: bool = True, + overlap: float = 0.25, transition_power: float = 1., + progress: bool = False, device=None, + num_workers: int = 0, segment: tp.Optional[float] = None, + pool=None, lock=None, + callback: tp.Optional[tp.Callable[[dict], None]] = None, + callback_arg: tp.Optional[dict] = None) -> th.Tensor: + """ + Apply model to a given mixture. + + Args: + shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec + and apply the oppositve shift to the output. This is repeated `shifts` time and + all predictions are averaged. This effectively makes the model time equivariant + and improves SDR by up to 0.2 points. + split (bool): if True, the input will be broken down in 8 seconds extracts + and predictions will be performed individually on each and concatenated. + Useful for model with large memory footprint like Tasnet. + progress (bool): if True, show a progress bar (requires split=True) + device (torch.device, str, or None): if provided, device on which to + execute the computation, otherwise `mix.device` is assumed. + When `device` is different from `mix.device`, only local computations will + be on `device`, while the entire tracks will be stored on `mix.device`. + num_workers (int): if non zero, device is 'cpu', how many threads to + use in parallel. + segment (float or None): override the model segment parameter. + """ + if device is None: + device = mix.device + else: + device = th.device(device) + if pool is None: + if num_workers > 0 and device.type == 'cpu': + pool = ThreadPoolExecutor(num_workers) + else: + pool = DummyPoolExecutor() + if lock is None: + lock = Lock() + callback_arg = _replace_dict( + callback_arg, *{"model_idx_in_bag": 0, "shift_idx": 0, "segment_offset": 0}.items() + ) + kwargs: tp.Dict[str, tp.Any] = { + 'shifts': shifts, + 'split': split, + 'overlap': overlap, + 'transition_power': transition_power, + 'progress': progress, + 'device': device, + 'pool': pool, + 'segment': segment, + 'lock': lock, + } + out: tp.Union[float, th.Tensor] + res: tp.Union[float, th.Tensor] + if isinstance(model, BagOfModels): + # Special treatment for bag of model. + # We explicitely apply multiple times `apply_model` so that the random shifts + # are different for each model. + estimates: tp.Union[float, th.Tensor] = 0. + totals = [0.] * len(model.sources) + callback_arg["models"] = len(model.models) + for sub_model, model_weights in zip(model.models, model.weights): + kwargs["callback"] = (( + lambda d, i=callback_arg["model_idx_in_bag"]: callback( + _replace_dict(d, ("model_idx_in_bag", i))) if callback else None) + ) + original_model_device = next(iter(sub_model.parameters())).device + sub_model.to(device) + + res = apply_model(sub_model, mix, **kwargs, callback_arg=callback_arg) + out = res + sub_model.to(original_model_device) + for k, inst_weight in enumerate(model_weights): + out[:, k, :, :] *= inst_weight + totals[k] += inst_weight + estimates += out + del out + callback_arg["model_idx_in_bag"] += 1 + + assert isinstance(estimates, th.Tensor) + for k in range(estimates.shape[1]): + estimates[:, k, :, :] /= totals[k] + return estimates + + if "models" not in callback_arg: + callback_arg["models"] = 1 + model.to(device) + model.eval() + assert transition_power >= 1, "transition_power < 1 leads to weird behavior." + batch, channels, length = mix.shape + if shifts: + kwargs['shifts'] = 0 + max_shift = int(0.5 * model.samplerate) + mix = tensor_chunk(mix) + assert isinstance(mix, TensorChunk) + padded_mix = mix.padded(length + 2 * max_shift) + out = 0. + for shift_idx in range(shifts): + offset = random.randint(0, max_shift) + shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) + kwargs["callback"] = ( + (lambda d, i=shift_idx: callback(_replace_dict(d, ("shift_idx", i))) + if callback else None) + ) + res = apply_model(model, shifted, **kwargs, callback_arg=callback_arg) + shifted_out = res + out += shifted_out[..., max_shift - offset:] + out /= shifts + assert isinstance(out, th.Tensor) + return out + elif split: + kwargs['split'] = False + out = th.zeros(batch, len(model.sources), channels, length, device=mix.device) + sum_weight = th.zeros(length, device=mix.device) + if segment is None: + segment = model.segment + assert segment is not None and segment > 0. + segment_length: int = int(model.samplerate * segment) + stride = int((1 - overlap) * segment_length) + offsets = range(0, length, stride) + scale = float(format(stride / model.samplerate, ".2f")) + # We start from a triangle shaped weight, with maximal weight in the middle + # of the segment. Then we normalize and take to the power `transition_power`. + # Large values of transition power will lead to sharper transitions. + weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device), + th.arange(segment_length - segment_length // 2, 0, -1, device=device)]) + assert len(weight) == segment_length + # If the overlap < 50%, this will translate to linear transition when + # transition_power is 1. + weight = (weight / weight.max())**transition_power + futures = [] + for offset in offsets: + chunk = TensorChunk(mix, offset, segment_length) + future = pool.submit(apply_model, model, chunk, **kwargs, callback_arg=callback_arg, + callback=(lambda d, i=offset: + callback(_replace_dict(d, ("segment_offset", i))) + if callback else None)) + futures.append((future, offset)) + offset += segment_length + if progress: + futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds') + for future, offset in futures: + try: + chunk_out = future.result() # type: th.Tensor + except Exception: + pool.shutdown(wait=True, cancel_futures=True) + raise + chunk_length = chunk_out.shape[-1] + out[..., offset:offset + segment_length] += ( + weight[:chunk_length] * chunk_out).to(mix.device) + sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device) + assert sum_weight.min() > 0 + out /= sum_weight + assert isinstance(out, th.Tensor) + return out + else: + valid_length: int + if isinstance(model, HTDemucs) and segment is not None: + valid_length = int(segment * model.samplerate) + elif hasattr(model, 'valid_length'): + valid_length = model.valid_length(length) # type: ignore + else: + valid_length = length + mix = tensor_chunk(mix) + assert isinstance(mix, TensorChunk) + padded_mix = mix.padded(valid_length).to(device) + with lock: + if callback is not None: + callback(_replace_dict(callback_arg, ("state", "start"))) # type: ignore + with th.no_grad(): + out = model(padded_mix) + with lock: + if callback is not None: + callback(_replace_dict(callback_arg, ("state", "end"))) # type: ignore + assert isinstance(out, th.Tensor) + return center_trim(out, length) diff --git a/AIMeiSheng/demucs/audio.py b/AIMeiSheng/demucs/audio.py new file mode 100644 index 0000000..31b29b3 --- /dev/null +++ b/AIMeiSheng/demucs/audio.py @@ -0,0 +1,265 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import json +import subprocess as sp +from pathlib import Path + +import lameenc +import julius +import numpy as np +import torch +import torchaudio as ta +import typing as tp + +from .utils import temp_filenames + + +def _read_info(path): + stdout_data = sp.check_output([ + 'ffprobe', "-loglevel", "panic", + str(path), '-print_format', 'json', '-show_format', '-show_streams' + ]) + return json.loads(stdout_data.decode('utf-8')) + + +class AudioFile: + """ + Allows to read audio from any format supported by ffmpeg, as well as resampling or + converting to mono on the fly. See :method:`read` for more details. + """ + def __init__(self, path: Path): + self.path = Path(path) + self._info = None + + def __repr__(self): + features = [("path", self.path)] + features.append(("samplerate", self.samplerate())) + features.append(("channels", self.channels())) + features.append(("streams", len(self))) + features_str = ", ".join(f"{name}={value}" for name, value in features) + return f"AudioFile({features_str})" + + @property + def info(self): + if self._info is None: + self._info = _read_info(self.path) + return self._info + + @property + def duration(self): + return float(self.info['format']['duration']) + + @property + def _audio_streams(self): + return [ + index for index, stream in enumerate(self.info["streams"]) + if stream["codec_type"] == "audio" + ] + + def __len__(self): + return len(self._audio_streams) + + def channels(self, stream=0): + return int(self.info['streams'][self._audio_streams[stream]]['channels']) + + def samplerate(self, stream=0): + return int(self.info['streams'][self._audio_streams[stream]]['sample_rate']) + + def read(self, + seek_time=None, + duration=None, + streams=slice(None), + samplerate=None, + channels=None): + """ + Slightly more efficient implementation than stempeg, + in particular, this will extract all stems at once + rather than having to loop over one file multiple times + for each stream. + + Args: + seek_time (float): seek time in seconds or None if no seeking is needed. + duration (float): duration in seconds to extract or None to extract until the end. + streams (slice, int or list): streams to extract, can be a single int, a list or + a slice. If it is a slice or list, the output will be of size [S, C, T] + with S the number of streams, C the number of channels and T the number of samples. + If it is an int, the output will be [C, T]. + samplerate (int): if provided, will resample on the fly. If None, no resampling will + be done. Original sampling rate can be obtained with :method:`samplerate`. + channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that + as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers. + See https://sound.stackexchange.com/a/42710. + Our definition of mono is simply the average of the two channels. Any other + value will be ignored. + """ + streams = np.array(range(len(self)))[streams] + single = not isinstance(streams, np.ndarray) + if single: + streams = [streams] + + if duration is None: + target_size = None + query_duration = None + else: + target_size = int((samplerate or self.samplerate()) * duration) + query_duration = float((target_size + 1) / (samplerate or self.samplerate())) + + with temp_filenames(len(streams)) as filenames: + command = ['ffmpeg', '-y'] + command += ['-loglevel', 'panic'] + if seek_time: + command += ['-ss', str(seek_time)] + command += ['-i', str(self.path)] + for stream, filename in zip(streams, filenames): + command += ['-map', f'0:{self._audio_streams[stream]}'] + if query_duration is not None: + command += ['-t', str(query_duration)] + command += ['-threads', '1'] + command += ['-f', 'f32le'] + if samplerate is not None: + command += ['-ar', str(samplerate)] + command += [filename] + + sp.run(command, check=True) + wavs = [] + for filename in filenames: + wav = np.fromfile(filename, dtype=np.float32) + wav = torch.from_numpy(wav) + wav = wav.view(-1, self.channels()).t() + if channels is not None: + wav = convert_audio_channels(wav, channels) + if target_size is not None: + wav = wav[..., :target_size] + wavs.append(wav) + wav = torch.stack(wavs, dim=0) + if single: + wav = wav[0] + return wav + + +def convert_audio_channels(wav, channels=2): + """Convert audio to the given number of channels.""" + *shape, src_channels, length = wav.shape + if src_channels == channels: + pass + elif channels == 1: + # Case 1: + # The caller asked 1-channel audio, but the stream have multiple + # channels, downmix all channels. + wav = wav.mean(dim=-2, keepdim=True) + elif src_channels == 1: + # Case 2: + # The caller asked for multiple channels, but the input file have + # one single channel, replicate the audio over all channels. + wav = wav.expand(*shape, channels, length) + elif src_channels >= channels: + # Case 3: + # The caller asked for multiple channels, and the input file have + # more channels than requested. In that case return the first channels. + wav = wav[..., :channels, :] + else: + # Case 4: What is a reasonable choice here? + raise ValueError('The audio file has less channels than requested but is not mono.') + return wav + + +def convert_audio(wav, from_samplerate, to_samplerate, channels) -> torch.Tensor: + """Convert audio from a given samplerate to a target one and target number of channels.""" + wav = convert_audio_channels(wav, channels) + return julius.resample_frac(wav, from_samplerate, to_samplerate) + + +def i16_pcm(wav): + """Convert audio to 16 bits integer PCM format.""" + if wav.dtype.is_floating_point: + return (wav.clamp_(-1, 1) * (2**15 - 1)).short() + else: + return wav + + +def f32_pcm(wav): + """Convert audio to float 32 bits PCM format.""" + if wav.dtype.is_floating_point: + return wav + else: + return wav.float() / (2**15 - 1) + + +def as_dtype_pcm(wav, dtype): + """Convert audio to either f32 pcm or i16 pcm depending on the given dtype.""" + if wav.dtype.is_floating_point: + return f32_pcm(wav) + else: + return i16_pcm(wav) + + +def encode_mp3(wav, path, samplerate=44100, bitrate=320, quality=2, verbose=False): + """Save given audio as mp3. This should work on all OSes.""" + C, T = wav.shape + wav = i16_pcm(wav) + encoder = lameenc.Encoder() + encoder.set_bit_rate(bitrate) + encoder.set_in_sample_rate(samplerate) + encoder.set_channels(C) + encoder.set_quality(quality) # 2-highest, 7-fastest + if not verbose: + encoder.silence() + wav = wav.data.cpu() + wav = wav.transpose(0, 1).numpy() + mp3_data = encoder.encode(wav.tobytes()) + mp3_data += encoder.flush() + with open(path, "wb") as f: + f.write(mp3_data) + + +def prevent_clip(wav, mode='rescale'): + """ + different strategies for avoiding raw clipping. + """ + if mode is None or mode == 'none': + return wav + assert wav.dtype.is_floating_point, "too late for clipping" + if mode == 'rescale': + wav = wav / max(1.01 * wav.abs().max(), 1) + elif mode == 'clamp': + wav = wav.clamp(-0.99, 0.99) + elif mode == 'tanh': + wav = torch.tanh(wav) + else: + raise ValueError(f"Invalid mode {mode}") + return wav + + +def save_audio(wav: torch.Tensor, + path: tp.Union[str, Path], + samplerate: int, + bitrate: int = 320, + clip: tp.Literal["rescale", "clamp", "tanh", "none"] = 'rescale', + bits_per_sample: tp.Literal[16, 24, 32] = 16, + as_float: bool = False, + preset: tp.Literal[2, 3, 4, 5, 6, 7] = 2): + """Save audio file, automatically preventing clipping if necessary + based on the given `clip` strategy. If the path ends in `.mp3`, this + will save as mp3 with the given `bitrate`. Use `preset` to set mp3 quality: + 2 for highest quality, 7 for fastest speed + """ + wav = prevent_clip(wav, mode=clip) + path = Path(path) + suffix = path.suffix.lower() + if suffix == ".mp3": + encode_mp3(wav, path, samplerate, bitrate, preset, verbose=True) + elif suffix == ".wav": + if as_float: + bits_per_sample = 32 + encoding = 'PCM_F' + else: + encoding = 'PCM_S' + ta.save(str(path), wav, sample_rate=samplerate, + encoding=encoding, bits_per_sample=bits_per_sample) + elif suffix == ".flac": + ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample) + else: + raise ValueError(f"Invalid suffix for path: {suffix}") diff --git a/AIMeiSheng/demucs/augment.py b/AIMeiSheng/demucs/augment.py new file mode 100644 index 0000000..6dab7f1 --- /dev/null +++ b/AIMeiSheng/demucs/augment.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Data augmentations. +""" + +import random +import torch as th +from torch import nn + + +class Shift(nn.Module): + """ + Randomly shift audio in time by up to `shift` samples. + """ + def __init__(self, shift=8192, same=False): + super().__init__() + self.shift = shift + self.same = same + + def forward(self, wav): + batch, sources, channels, time = wav.size() + length = time - self.shift + if self.shift > 0: + if not self.training: + wav = wav[..., :length] + else: + srcs = 1 if self.same else sources + offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device) + offsets = offsets.expand(-1, sources, channels, -1) + indexes = th.arange(length, device=wav.device) + wav = wav.gather(3, indexes + offsets) + return wav + + +class FlipChannels(nn.Module): + """ + Flip left-right channels. + """ + def forward(self, wav): + batch, sources, channels, time = wav.size() + if self.training and wav.size(2) == 2: + left = th.randint(2, (batch, sources, 1, 1), device=wav.device) + left = left.expand(-1, -1, -1, time) + right = 1 - left + wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2) + return wav + + +class FlipSign(nn.Module): + """ + Random sign flip. + """ + def forward(self, wav): + batch, sources, channels, time = wav.size() + if self.training: + signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32) + wav = wav * (2 * signs - 1) + return wav + + +class Remix(nn.Module): + """ + Shuffle sources to make new mixes. + """ + def __init__(self, proba=1, group_size=4): + """ + Shuffle sources within one batch. + Each batch is divided into groups of size `group_size` and shuffling is done within + each group separatly. This allow to keep the same probability distribution no matter + the number of GPUs. Without this grouping, using more GPUs would lead to a higher + probability of keeping two sources from the same track together which can impact + performance. + """ + super().__init__() + self.proba = proba + self.group_size = group_size + + def forward(self, wav): + batch, streams, channels, time = wav.size() + device = wav.device + + if self.training and random.random() < self.proba: + group_size = self.group_size or batch + if batch % group_size != 0: + raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}") + groups = batch // group_size + wav = wav.view(groups, group_size, streams, channels, time) + permutations = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device), + dim=1) + wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time)) + wav = wav.view(batch, streams, channels, time) + return wav + + +class Scale(nn.Module): + def __init__(self, proba=1., min=0.25, max=1.25): + super().__init__() + self.proba = proba + self.min = min + self.max = max + + def forward(self, wav): + batch, streams, channels, time = wav.size() + device = wav.device + if self.training and random.random() < self.proba: + scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max) + wav *= scales + return wav diff --git a/AIMeiSheng/demucs/demucs.py b/AIMeiSheng/demucs/demucs.py new file mode 100644 index 0000000..f6a4305 --- /dev/null +++ b/AIMeiSheng/demucs/demucs.py @@ -0,0 +1,447 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import typing as tp + +import julius +import torch +from torch import nn +from torch.nn import functional as F + +from .states import capture_init +from .utils import center_trim, unfold +from .transformer import LayerScale + + +class BLSTM(nn.Module): + """ + BiLSTM with same hidden units as input dim. + If `max_steps` is not None, input will be splitting in overlapping + chunks and the LSTM applied separately on each chunk. + """ + def __init__(self, dim, layers=1, max_steps=None, skip=False): + super().__init__() + assert max_steps is None or max_steps % 4 == 0 + self.max_steps = max_steps + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + self.skip = skip + + def forward(self, x): + B, C, T = x.shape + y = x + framed = False + if self.max_steps is not None and T > self.max_steps: + width = self.max_steps + stride = width // 2 + frames = unfold(x, width, stride) + nframes = frames.shape[2] + framed = True + x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) + + x = x.permute(2, 0, 1) + + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + if framed: + out = [] + frames = x.reshape(B, -1, C, width) + limit = stride // 2 + for k in range(nframes): + if k == 0: + out.append(frames[:, k, :, :-limit]) + elif k == nframes - 1: + out.append(frames[:, k, :, limit:]) + else: + out.append(frames[:, k, :, limit:-limit]) + out = torch.cat(out, -1) + out = out[..., :T] + x = out + if self.skip: + x = x + y + return x + + +def rescale_conv(conv, reference): + """Rescale initial weight scale. It is unclear why it helps but it certainly does. + """ + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): + rescale_conv(sub, reference) + + +class DConv(nn.Module): + """ + New residual branches in each encoder layer. + This alternates dilated convolutions, potentially with LSTMs and attention. + Also before entering each residual branch, dimension is projected on a smaller subspace, + e.g. of dim `channels // compress`. + """ + def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, + norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, + kernel=3, dilate=True): + """ + Args: + channels: input/output channels for residual branch. + compress: amount of channel compression inside the branch. + depth: number of layers in the residual branch. Each layer has its own + projection, and potentially LSTM and attention. + init: initial scale for LayerNorm. + norm: use GroupNorm. + attn: use LocalAttention. + heads: number of heads for the LocalAttention. + ndecay: number of decay controls in the LocalAttention. + lstm: use LSTM. + gelu: Use GELU activation. + kernel: kernel size for the (dilated) convolutions. + dilate: if true, use dilation, increasing with the depth. + """ + + super().__init__() + assert kernel % 2 == 1 + self.channels = channels + self.compress = compress + self.depth = abs(depth) + dilate = depth > 0 + + norm_fn: tp.Callable[[int], nn.Module] + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(1, d) # noqa + + hidden = int(channels / compress) + + act: tp.Type[nn.Module] + if gelu: + act = nn.GELU + else: + act = nn.ReLU + + self.layers = nn.ModuleList([]) + for d in range(self.depth): + dilation = 2 ** d if dilate else 1 + padding = dilation * (kernel // 2) + mods = [ + nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), + norm_fn(hidden), act(), + nn.Conv1d(hidden, 2 * channels, 1), + norm_fn(2 * channels), nn.GLU(1), + LayerScale(channels, init), + ] + if attn: + mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay)) + if lstm: + mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True)) + layer = nn.Sequential(*mods) + self.layers.append(layer) + + def forward(self, x): + for layer in self.layers: + x = x + layer(x) + return x + + +class LocalState(nn.Module): + """Local state allows to have attention based only on data (no positional embedding), + but while setting a constraint on the time window (e.g. decaying penalty term). + + Also a failed experiments with trying to provide some frequency based attention. + """ + def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4): + super().__init__() + assert channels % heads == 0, (channels, heads) + self.heads = heads + self.nfreqs = nfreqs + self.ndecay = ndecay + self.content = nn.Conv1d(channels, channels, 1) + self.query = nn.Conv1d(channels, channels, 1) + self.key = nn.Conv1d(channels, channels, 1) + if nfreqs: + self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1) + if ndecay: + self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) + # Initialize decay close to zero (there is a sigmoid), for maximum initial window. + self.query_decay.weight.data *= 0.01 + assert self.query_decay.bias is not None # stupid type checker + self.query_decay.bias.data[:] = -2 + self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1) + + def forward(self, x): + B, C, T = x.shape + heads = self.heads + indexes = torch.arange(T, device=x.device, dtype=x.dtype) + # left index are keys, right index are queries + delta = indexes[:, None] - indexes[None, :] + + queries = self.query(x).view(B, heads, -1, T) + keys = self.key(x).view(B, heads, -1, T) + # t are keys, s are queries + dots = torch.einsum("bhct,bhcs->bhts", keys, queries) + dots /= keys.shape[2]**0.5 + if self.nfreqs: + periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype) + freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1)) + freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5 + dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q) + if self.ndecay: + decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) + decay_q = self.query_decay(x).view(B, heads, -1, T) + decay_q = torch.sigmoid(decay_q) / 2 + decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5 + dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) + + # Kill self reference. + dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100) + weights = torch.softmax(dots, dim=2) + + content = self.content(x).view(B, heads, -1, T) + result = torch.einsum("bhts,bhct->bhcs", weights, content) + if self.nfreqs: + time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel) + result = torch.cat([result, time_sig], 2) + result = result.reshape(B, -1, T) + return x + self.proj(result) + + +class Demucs(nn.Module): + @capture_init + def __init__(self, + sources, + # Channels + audio_channels=2, + channels=64, + growth=2., + # Main structure + depth=6, + rewrite=True, + lstm_layers=0, + # Convolutions + kernel_size=8, + stride=4, + context=1, + # Activations + gelu=True, + glu=True, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=4, + dconv_attn=4, + dconv_lstm=4, + dconv_init=1e-4, + # Pre/post processing + normalize=True, + resample=True, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=4 * 10): + """ + Args: + sources (list[str]): list of source names + audio_channels (int): stereo or mono + channels (int): first convolution channels + depth (int): number of encoder/decoder layers + growth (float): multiply (resp divide) number of channels by that + for each layer of the encoder (resp decoder) + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated + by default, as this is now replaced by the smaller and faster small LSTMs + in the DConv branches. + kernel_size (int): kernel size for convolutions + stride (int): stride for convolutions + context (int): kernel size of the convolution in the + decoder before the transposed convolution. If > 1, + will provide some context from neighboring time steps. + gelu: use GELU activation function. + glu (bool): use glu instead of ReLU for the 1x1 rewrite conv. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + normalize (bool): normalizes the input audio on the fly, and scales back + the output by the same amount. + resample (bool): upsample x2 the input and downsample /2 the output. + rescale (float): rescale initial weights of convolutions + to get their standard deviation closer to `rescale`. + samplerate (int): stored as meta information for easing + future evaluations of the model. + segment (float): duration of the chunks of audio to ideally evaluate the model on. + This is used by `demucs.apply.apply_model`. + """ + + super().__init__() + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.resample = resample + self.channels = channels + self.normalize = normalize + self.samplerate = samplerate + self.segment = segment + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + self.skip_scales = nn.ModuleList() + + if glu: + activation = nn.GLU(dim=1) + ch_scale = 2 + else: + activation = nn.ReLU() + ch_scale = 1 + if gelu: + act2 = nn.GELU + else: + act2 = nn.ReLU + + in_channels = audio_channels + padding = 0 + for index in range(depth): + norm_fn = lambda d: nn.Identity() # noqa + if index >= norm_starts: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + + encode = [] + encode += [ + nn.Conv1d(in_channels, channels, kernel_size, stride), + norm_fn(channels), + act2(), + ] + attn = index >= dconv_attn + lstm = index >= dconv_lstm + if dconv_mode & 1: + encode += [DConv(channels, depth=dconv_depth, init=dconv_init, + compress=dconv_comp, attn=attn, lstm=lstm)] + if rewrite: + encode += [ + nn.Conv1d(channels, ch_scale * channels, 1), + norm_fn(ch_scale * channels), activation] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + if index > 0: + out_channels = in_channels + else: + out_channels = len(self.sources) * audio_channels + if rewrite: + decode += [ + nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), + norm_fn(ch_scale * channels), activation] + if dconv_mode & 2: + decode += [DConv(channels, depth=dconv_depth, init=dconv_init, + compress=dconv_comp, attn=attn, lstm=lstm)] + decode += [nn.ConvTranspose1d(channels, out_channels, + kernel_size, stride, padding=padding)] + if index > 0: + decode += [norm_fn(out_channels), act2()] + self.decoder.insert(0, nn.Sequential(*decode)) + in_channels = channels + channels = int(growth * channels) + + channels = in_channels + if lstm_layers: + self.lstm = BLSTM(channels, lstm_layers) + else: + self.lstm = None + + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolution, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + Note that input are automatically padded if necessary to ensure that the output + has the same length as the input. + """ + if self.resample: + length *= 2 + + for _ in range(self.depth): + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(1, length) + + for idx in range(self.depth): + length = (length - 1) * self.stride + self.kernel_size + + if self.resample: + length = math.ceil(length / 2) + return int(length) + + def forward(self, mix): + x = mix + length = x.shape[-1] + + if self.normalize: + mono = mix.mean(dim=1, keepdim=True) + mean = mono.mean(dim=-1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + x = (x - mean) / (1e-5 + std) + else: + mean = 0 + std = 1 + + delta = self.valid_length(length) - length + x = F.pad(x, (delta // 2, delta - delta // 2)) + + if self.resample: + x = julius.resample_frac(x, 1, 2) + + saved = [] + for encode in self.encoder: + x = encode(x) + saved.append(x) + + if self.lstm: + x = self.lstm(x) + + for decode in self.decoder: + skip = saved.pop(-1) + skip = center_trim(skip, x) + x = decode(x + skip) + + if self.resample: + x = julius.resample_frac(x, 2, 1) + x = x * std + mean + x = center_trim(x, length) + x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) + return x + + def load_state_dict(self, state, strict=True): + # fix a mismatch with previous generation Demucs models. + for idx in range(self.depth): + for a in ['encoder', 'decoder']: + for b in ['bias', 'weight']: + new = f'{a}.{idx}.3.{b}' + old = f'{a}.{idx}.2.{b}' + if old in state and new not in state: + state[new] = state.pop(old) + super().load_state_dict(state, strict=strict) diff --git a/AIMeiSheng/demucs/distrib.py b/AIMeiSheng/demucs/distrib.py new file mode 100644 index 0000000..dc1576c --- /dev/null +++ b/AIMeiSheng/demucs/distrib.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Distributed training utilities. +""" +import logging +import pickle + +import numpy as np +import torch +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader, Subset +from torch.nn.parallel.distributed import DistributedDataParallel + +from dora import distrib as dora_distrib + +logger = logging.getLogger(__name__) +rank = 0 +world_size = 1 + + +def init(): + global rank, world_size + if not torch.distributed.is_initialized(): + dora_distrib.init() + rank = dora_distrib.rank() + world_size = dora_distrib.world_size() + + +def average(metrics, count=1.): + if isinstance(metrics, dict): + keys, values = zip(*sorted(metrics.items())) + values = average(values, count) + return dict(zip(keys, values)) + if world_size == 1: + return metrics + tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32) + tensor *= count + torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) + return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist() + + +def wrap(model): + if world_size == 1: + return model + else: + return DistributedDataParallel( + model, + # find_unused_parameters=True, + device_ids=[torch.cuda.current_device()], + output_device=torch.cuda.current_device()) + + +def barrier(): + if world_size > 1: + torch.distributed.barrier() + + +def share(obj=None, src=0): + if world_size == 1: + return obj + size = torch.empty(1, device='cuda', dtype=torch.long) + if rank == src: + dump = pickle.dumps(obj) + size[0] = len(dump) + torch.distributed.broadcast(size, src=src) + # size variable is now set to the length of pickled obj in all processes + + if rank == src: + buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda() + else: + buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8) + torch.distributed.broadcast(buffer, src=src) + # buffer variable is now set to pickled obj in all processes + + if rank != src: + obj = pickle.loads(buffer.cpu().numpy().tobytes()) + logger.debug(f"Shared object of size {len(buffer)}") + return obj + + +def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs): + """ + Create a dataloader properly in case of distributed training. + If a gradient is going to be computed you must set `shuffle=True`. + """ + if world_size == 1: + return klass(dataset, *args, shuffle=shuffle, **kwargs) + + if shuffle: + # train means we will compute backward, we use DistributedSampler + sampler = DistributedSampler(dataset) + # We ignore shuffle, DistributedSampler already shuffles + return klass(dataset, *args, **kwargs, sampler=sampler) + else: + # We make a manual shard, as DistributedSampler otherwise replicate some examples + dataset = Subset(dataset, list(range(rank, len(dataset), world_size))) + return klass(dataset, *args, shuffle=shuffle, **kwargs) diff --git a/AIMeiSheng/demucs/ema.py b/AIMeiSheng/demucs/ema.py new file mode 100644 index 0000000..101bee0 --- /dev/null +++ b/AIMeiSheng/demucs/ema.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Inspired from https://github.com/rwightman/pytorch-image-models +from contextlib import contextmanager + +import torch + +from .states import swap_state + + +class ModelEMA: + """ + Perform EMA on a model. You can switch to the EMA weights temporarily + with the `swap` method. + + ema = ModelEMA(model) + with ema.swap(): + # compute valid metrics with averaged model. + """ + def __init__(self, model, decay=0.9999, unbias=True, device='cpu'): + self.decay = decay + self.model = model + self.state = {} + self.count = 0 + self.device = device + self.unbias = unbias + + self._init() + + def _init(self): + for key, val in self.model.state_dict().items(): + if val.dtype != torch.float32: + continue + device = self.device or val.device + if key not in self.state: + self.state[key] = val.detach().to(device, copy=True) + + def update(self): + if self.unbias: + self.count = self.count * self.decay + 1 + w = 1 / self.count + else: + w = 1 - self.decay + for key, val in self.model.state_dict().items(): + if val.dtype != torch.float32: + continue + device = self.device or val.device + self.state[key].mul_(1 - w) + self.state[key].add_(val.detach().to(device), alpha=w) + + @contextmanager + def swap(self): + with swap_state(self.model, self.state): + yield + + def state_dict(self): + return {'state': self.state, 'count': self.count} + + def load_state_dict(self, state): + self.count = state['count'] + for k, v in state['state'].items(): + self.state[k].copy_(v) diff --git a/AIMeiSheng/demucs/evaluate.py b/AIMeiSheng/demucs/evaluate.py new file mode 100644 index 0000000..fa2ff45 --- /dev/null +++ b/AIMeiSheng/demucs/evaluate.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Test time evaluation, either using the original SDR from [Vincent et al. 2006] +or the newest SDR definition from the MDX 2021 competition (this one will +be reported as `nsdr` for `new sdr`). +""" + +from concurrent import futures +import logging + +from dora.log import LogProgress +import numpy as np +import musdb +import museval +import torch as th + +from .apply import apply_model +from .audio import convert_audio, save_audio +from . import distrib +from .utils import DummyPoolExecutor + + +logger = logging.getLogger(__name__) + + +def new_sdr(references, estimates): + """ + Compute the SDR according to the MDX challenge definition. + Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license) + """ + assert references.dim() == 4 + assert estimates.dim() == 4 + delta = 1e-7 # avoid numerical errors + num = th.sum(th.square(references), dim=(2, 3)) + den = th.sum(th.square(references - estimates), dim=(2, 3)) + num += delta + den += delta + scores = 10 * th.log10(num / den) + return scores + + +def eval_track(references, estimates, win, hop, compute_sdr=True): + references = references.transpose(1, 2).double() + estimates = estimates.transpose(1, 2).double() + + new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0] + + if not compute_sdr: + return None, new_scores + else: + references = references.numpy() + estimates = estimates.numpy() + scores = museval.metrics.bss_eval( + references, estimates, + compute_permutation=False, + window=win, + hop=hop, + framewise_filters=False, + bsseval_sources_version=False)[:-1] + return scores, new_scores + + +def evaluate(solver, compute_sdr=False): + """ + Evaluate model using museval. + compute_sdr=False means using only the MDX definition of the SDR, which + is much faster to evaluate. + """ + + args = solver.args + + output_dir = solver.folder / "results" + output_dir.mkdir(exist_ok=True, parents=True) + json_folder = solver.folder / "results/test" + json_folder.mkdir(exist_ok=True, parents=True) + + # we load tracks from the original musdb set + if args.test.nonhq is None: + test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True) + else: + test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False) + src_rate = args.dset.musdb_samplerate + + eval_device = 'cpu' + + model = solver.model + win = int(1. * model.samplerate) + hop = int(1. * model.samplerate) + + indexes = range(distrib.rank, len(test_set), distrib.world_size) + indexes = LogProgress(logger, indexes, updates=args.misc.num_prints, + name='Eval') + pendings = [] + + pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor + with pool(args.test.workers) as pool: + for index in indexes: + track = test_set.tracks[index] + + mix = th.from_numpy(track.audio).t().float() + if mix.dim() == 1: + mix = mix[None] + mix = mix.to(solver.device) + ref = mix.mean(dim=0) # mono mixture + mix = (mix - ref.mean()) / ref.std() + mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels) + estimates = apply_model(model, mix[None], + shifts=args.test.shifts, split=args.test.split, + overlap=args.test.overlap)[0] + estimates = estimates * ref.std() + ref.mean() + estimates = estimates.to(eval_device) + + references = th.stack( + [th.from_numpy(track.targets[name].audio).t() for name in model.sources]) + if references.dim() == 2: + references = references[:, None] + references = references.to(eval_device) + references = convert_audio(references, src_rate, + model.samplerate, model.audio_channels) + if args.test.save: + folder = solver.folder / "wav" / track.name + folder.mkdir(exist_ok=True, parents=True) + for name, estimate in zip(model.sources, estimates): + save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate) + + pendings.append((track.name, pool.submit( + eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr))) + + pendings = LogProgress(logger, pendings, updates=args.misc.num_prints, + name='Eval (BSS)') + tracks = {} + for track_name, pending in pendings: + pending = pending.result() + scores, nsdrs = pending + tracks[track_name] = {} + for idx, target in enumerate(model.sources): + tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]} + if scores is not None: + (sdr, isr, sir, sar) = scores + for idx, target in enumerate(model.sources): + values = { + "SDR": sdr[idx].tolist(), + "SIR": sir[idx].tolist(), + "ISR": isr[idx].tolist(), + "SAR": sar[idx].tolist() + } + tracks[track_name][target].update(values) + + all_tracks = {} + for src in range(distrib.world_size): + all_tracks.update(distrib.share(tracks, src)) + + result = {} + metric_names = next(iter(all_tracks.values()))[model.sources[0]] + for metric_name in metric_names: + avg = 0 + avg_of_medians = 0 + for source in model.sources: + medians = [ + np.nanmedian(all_tracks[track][source][metric_name]) + for track in all_tracks.keys()] + mean = np.mean(medians) + median = np.median(medians) + result[metric_name.lower() + "_" + source] = mean + result[metric_name.lower() + "_med" + "_" + source] = median + avg += mean / len(model.sources) + avg_of_medians += median / len(model.sources) + result[metric_name.lower()] = avg + result[metric_name.lower() + "_med"] = avg_of_medians + return result diff --git a/AIMeiSheng/demucs/grids/__init__.py b/AIMeiSheng/demucs/grids/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/AIMeiSheng/demucs/grids/_explorers.py b/AIMeiSheng/demucs/grids/_explorers.py new file mode 100644 index 0000000..ec3a858 --- /dev/null +++ b/AIMeiSheng/demucs/grids/_explorers.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dora import Explorer +import treetable as tt + + +class MyExplorer(Explorer): + test_metrics = ['nsdr', 'sdr_med'] + + def get_grid_metrics(self): + """Return the metrics that should be displayed in the tracking table. + """ + return [ + tt.group("train", [ + tt.leaf("epoch"), + tt.leaf("reco", ".3f"), + ], align=">"), + tt.group("valid", [ + tt.leaf("penalty", ".1f"), + tt.leaf("ms", ".1f"), + tt.leaf("reco", ".2%"), + tt.leaf("breco", ".2%"), + tt.leaf("b_nsdr", ".2f"), + # tt.leaf("b_nsdr_drums", ".2f"), + # tt.leaf("b_nsdr_bass", ".2f"), + # tt.leaf("b_nsdr_other", ".2f"), + # tt.leaf("b_nsdr_vocals", ".2f"), + ], align=">"), + tt.group("test", [ + tt.leaf(name, ".2f") + for name in self.test_metrics + ], align=">") + ] + + def process_history(self, history): + train = { + 'epoch': len(history), + } + valid = {} + test = {} + best_v_main = float('inf') + breco = float('inf') + for metrics in history: + train.update(metrics['train']) + valid.update(metrics['valid']) + if 'main' in metrics['valid']: + best_v_main = min(best_v_main, metrics['valid']['main']['loss']) + valid['bmain'] = best_v_main + valid['breco'] = min(breco, metrics['valid']['reco']) + breco = valid['breco'] + if (metrics['valid']['loss'] == metrics['valid']['best'] or + metrics['valid'].get('nsdr') == metrics['valid']['best']): + for k, v in metrics['valid'].items(): + if k.startswith('reco_'): + valid['b_' + k[len('reco_'):]] = v + if k.startswith('nsdr'): + valid[f'b_{k}'] = v + if 'test' in metrics: + test.update(metrics['test']) + metrics = history[-1] + return {"train": train, "valid": valid, "test": test} diff --git a/AIMeiSheng/demucs/grids/mdx.py b/AIMeiSheng/demucs/grids/mdx.py new file mode 100644 index 0000000..62d447f --- /dev/null +++ b/AIMeiSheng/demucs/grids/mdx.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Main training for the Track A MDX models. +""" + +from ._explorers import MyExplorer +from ..train import main + + +TRACK_A = ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68'] + + +@MyExplorer +def explorer(launcher): + launcher.slurm_( + gpus=8, + time=3 * 24 * 60, + partition='learnlab') + + # Reproduce results from MDX competition Track A + # This trains the first round of models. Once this is trained, + # you will need to schedule `mdx_refine`. + for sig in TRACK_A: + xp = main.get_xp_from_sig(sig) + parent = xp.cfg.continue_from + xp = main.get_xp_from_sig(parent) + launcher(xp.argv) + launcher(xp.argv, {'quant.diffq': 1e-4}) + launcher(xp.argv, {'quant.diffq': 3e-4}) diff --git a/AIMeiSheng/demucs/grids/mdx_extra.py b/AIMeiSheng/demucs/grids/mdx_extra.py new file mode 100644 index 0000000..b99a37b --- /dev/null +++ b/AIMeiSheng/demucs/grids/mdx_extra.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Main training for the Track A MDX models. +""" + +from ._explorers import MyExplorer +from ..train import main + +TRACK_B = ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08'] + + +@MyExplorer +def explorer(launcher): + launcher.slurm_( + gpus=8, + time=3 * 24 * 60, + partition='learnlab') + + # Reproduce results from MDX competition Track A + # This trains the first round of models. Once this is trained, + # you will need to schedule `mdx_refine`. + for sig in TRACK_B: + while sig is not None: + xp = main.get_xp_from_sig(sig) + sig = xp.cfg.continue_from + + for dset in ['extra44', 'extra_test']: + sub = launcher.bind(xp.argv, dset=dset) + sub() + if dset == 'extra_test': + sub({'quant.diffq': 1e-4}) + sub({'quant.diffq': 3e-4}) diff --git a/AIMeiSheng/demucs/grids/mdx_refine.py b/AIMeiSheng/demucs/grids/mdx_refine.py new file mode 100644 index 0000000..f62da1d --- /dev/null +++ b/AIMeiSheng/demucs/grids/mdx_refine.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Main training for the Track A MDX models. +""" + +from ._explorers import MyExplorer +from .mdx import TRACK_A +from ..train import main + + +@MyExplorer +def explorer(launcher): + launcher.slurm_( + gpus=8, + time=3 * 24 * 60, + partition='learnlab') + + # Reproduce results from MDX competition Track A + # WARNING: all the experiments in the `mdx` grid must have completed. + for sig in TRACK_A: + xp = main.get_xp_from_sig(sig) + launcher(xp.argv) + for diffq in [1e-4, 3e-4]: + xp_src = main.get_xp_from_sig(xp.cfg.continue_from) + q_argv = [f'quant.diffq={diffq}'] + actual_src = main.get_xp(xp_src.argv + q_argv) + actual_src.link.load() + assert len(actual_src.link.history) == actual_src.cfg.epochs + argv = xp.argv + q_argv + [f'continue_from="{actual_src.sig}"'] + launcher(argv) diff --git a/AIMeiSheng/demucs/grids/mmi.py b/AIMeiSheng/demucs/grids/mmi.py new file mode 100644 index 0000000..d75aa2b --- /dev/null +++ b/AIMeiSheng/demucs/grids/mmi.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from ._explorers import MyExplorer +from dora import Launcher + + +@MyExplorer +def explorer(launcher: Launcher): + launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="devlab,learnlab,learnfair") # 3 days + + sub = launcher.bind_( + { + "dset": "extra_mmi_goodclean", + "test.shifts": 0, + "model": "htdemucs", + "htdemucs.dconv_mode": 3, + "htdemucs.depth": 4, + "htdemucs.t_dropout": 0.02, + "htdemucs.t_layers": 5, + "max_batches": 800, + "ema.epoch": [0.9, 0.95], + "ema.batch": [0.9995, 0.9999], + "dset.segment": 10, + "batch_size": 32, + } + ) + sub({"model": "hdemucs"}) + sub({"model": "hdemucs", "dset": "extra44"}) + sub({"model": "hdemucs", "dset": "musdb44"}) + + sparse = { + 'batch_size': 3 * 8, + 'augment.remix.group_size': 3, + 'htdemucs.t_auto_sparsity': True, + 'htdemucs.t_sparse_self_attn': True, + 'htdemucs.t_sparse_cross_attn': True, + 'htdemucs.t_sparsity': 0.9, + "htdemucs.t_layers": 7 + } + + with launcher.job_array(): + for transf_layers in [5, 7]: + for bottom_channels in [0, 512]: + sub = launcher.bind({ + "htdemucs.t_layers": transf_layers, + "htdemucs.bottom_channels": bottom_channels, + }) + if bottom_channels == 0 and transf_layers == 5: + sub({"augment.remix.proba": 0.0}) + sub({ + "augment.repitch.proba": 0.0, + # when doing repitching, we trim the outut to align on the + # highest change of BPM. When removing repitching, + # we simulate it here to ensure the training context is the same. + # Another second is lost for all experiments due to the random + # shift augmentation. + "dset.segment": 10 * 0.88}) + elif bottom_channels == 512 and transf_layers == 5: + sub(dset="musdb44") + sub(dset="extra44") + # Sparse kernel XP, currently not released as kernels are still experimental. + sub(sparse, {'dset.segment': 15, "htdemucs.t_layers": 7}) + + for duration in [5, 10, 15]: + sub({"dset.segment": duration}) diff --git a/AIMeiSheng/demucs/grids/mmi_ft.py b/AIMeiSheng/demucs/grids/mmi_ft.py new file mode 100644 index 0000000..73e488b --- /dev/null +++ b/AIMeiSheng/demucs/grids/mmi_ft.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from ._explorers import MyExplorer +from dora import Launcher +from demucs import train + + +def get_sub(launcher, sig): + xp = train.main.get_xp_from_sig(sig) + sub = launcher.bind(xp.argv) + sub() + sub.bind_({ + 'continue_from': sig, + 'continue_best': True}) + return sub + + +@MyExplorer +def explorer(launcher: Launcher): + launcher.slurm_(gpus=4, time=3 * 24 * 60, partition="devlab,learnlab,learnfair") # 3 days + ft = { + 'optim.lr': 1e-4, + 'augment.remix.proba': 0, + 'augment.scale.proba': 0, + 'augment.shift_same': True, + 'htdemucs.t_weight_decay': 0.05, + 'batch_size': 8, + 'optim.clip_grad': 5, + 'optim.optim': 'adamw', + 'epochs': 50, + 'dset.wav2_valid': True, + 'ema.epoch': [], # let's make valid a bit faster + } + with launcher.job_array(): + for sig in ['2899e11a']: + sub = get_sub(launcher, sig) + sub.bind_(ft) + for segment in [15, 18]: + for source in range(4): + w = [0] * 4 + w[source] = 1 + sub({'weights': w, 'dset.segment': segment}) + + for sig in ['955717e8']: + sub = get_sub(launcher, sig) + sub.bind_(ft) + for segment in [10, 15]: + for source in range(4): + w = [0] * 4 + w[source] = 1 + sub({'weights': w, 'dset.segment': segment}) diff --git a/AIMeiSheng/demucs/grids/repro.py b/AIMeiSheng/demucs/grids/repro.py new file mode 100644 index 0000000..21d33fc --- /dev/null +++ b/AIMeiSheng/demucs/grids/repro.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Easier training for reproducibility +""" + +from ._explorers import MyExplorer + + +@MyExplorer +def explorer(launcher): + launcher.slurm_( + gpus=8, + time=3 * 24 * 60, + partition='devlab,learnlab') + + launcher.bind_({'ema.epoch': [0.9, 0.95]}) + launcher.bind_({'ema.batch': [0.9995, 0.9999]}) + launcher.bind_({'epochs': 600}) + + base = {'model': 'demucs', 'demucs.dconv_mode': 0, 'demucs.gelu': False, + 'demucs.lstm_layers': 2} + newt = {'model': 'demucs', 'demucs.normalize': True} + hdem = {'model': 'hdemucs'} + svd = {'svd.penalty': 1e-5, 'svd': 'base2'} + + with launcher.job_array(): + for model in [base, newt, hdem]: + sub = launcher.bind(model) + if model is base: + # Training the v2 Demucs on MusDB HQ + sub(epochs=360) + continue + + # those two will be used in the repro_mdx_a bag of models. + sub(svd) + sub(svd, seed=43) + if model == newt: + # Ablation study + sub() + abl = sub.bind(svd) + abl({'ema.epoch': [], 'ema.batch': []}) + abl({'demucs.dconv_lstm': 10}) + abl({'demucs.dconv_attn': 10}) + abl({'demucs.dconv_attn': 10, 'demucs.dconv_lstm': 10, 'demucs.lstm_layers': 2}) + abl({'demucs.dconv_mode': 0}) + abl({'demucs.gelu': False}) diff --git a/AIMeiSheng/demucs/grids/repro_ft.py b/AIMeiSheng/demucs/grids/repro_ft.py new file mode 100644 index 0000000..7bb4ee8 --- /dev/null +++ b/AIMeiSheng/demucs/grids/repro_ft.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Fine tuning experiments +""" + +from ._explorers import MyExplorer +from ..train import main + + +@MyExplorer +def explorer(launcher): + launcher.slurm_( + gpus=8, + time=300, + partition='devlab,learnlab') + + # Mus + launcher.slurm_(constraint='volta32gb') + + grid = "repro" + folder = main.dora.dir / "grids" / grid + + for sig in folder.iterdir(): + if not sig.is_symlink(): + continue + xp = main.get_xp_from_sig(sig) + xp.link.load() + if len(xp.link.history) != xp.cfg.epochs: + continue + sub = launcher.bind(xp.argv, [f'continue_from="{xp.sig}"']) + sub.bind_({'ema.epoch': [0.9, 0.95], 'ema.batch': [0.9995, 0.9999]}) + sub.bind_({'test.every': 1, 'test.sdr': True, 'epochs': 4}) + sub.bind_({'dset.segment': 28, 'dset.shift': 2}) + sub.bind_({'batch_size': 32}) + auto = {'dset': 'auto_mus'} + auto.update({'augment.remix.proba': 0, 'augment.scale.proba': 0, + 'augment.shift_same': True}) + sub.bind_(auto) + sub.bind_({'batch_size': 16}) + sub.bind_({'optim.lr': 1e-4}) + sub.bind_({'model_segment': 44}) + sub() diff --git a/AIMeiSheng/demucs/grids/sdx23.py b/AIMeiSheng/demucs/grids/sdx23.py new file mode 100644 index 0000000..3bdb419 --- /dev/null +++ b/AIMeiSheng/demucs/grids/sdx23.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from ._explorers import MyExplorer +from dora import Launcher + + +@MyExplorer +def explorer(launcher: Launcher): + launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="speechgpt,learnfair", + mem_per_gpu=None, constraint='') + launcher.bind_({"dset.use_musdb": False}) + + with launcher.job_array(): + launcher(dset='sdx23_bleeding') + launcher(dset='sdx23_labelnoise') diff --git a/AIMeiSheng/demucs/hdemucs.py b/AIMeiSheng/demucs/hdemucs.py new file mode 100644 index 0000000..711d471 --- /dev/null +++ b/AIMeiSheng/demucs/hdemucs.py @@ -0,0 +1,794 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +This code contains the spectrogram and Hybrid version of Demucs. +""" +from copy import deepcopy +import math +import typing as tp + +from openunmix.filtering import wiener +import torch +from torch import nn +from torch.nn import functional as F + +from .demucs import DConv, rescale_module +from .states import capture_init +from .spec import spectro, ispectro + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen.""" + x0 = x + length = x.shape[-1] + padding_left, padding_right = paddings + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + if length <= max_pad: + extra_pad = max_pad - length + 1 + extra_pad_right = min(padding_right, extra_pad) + extra_pad_left = extra_pad - extra_pad_right + paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right) + x = F.pad(x, (extra_pad_left, extra_pad_right)) + out = F.pad(x, paddings, mode, value) + assert out.shape[-1] == length + padding_left + padding_right + assert (out[..., padding_left: padding_left + length] == x0).all() + return out + + +class ScaledEmbedding(nn.Module): + """ + Boost learning rate for embeddings (with `scale`). + Also, can make embeddings continuous with `smooth`. + """ + def __init__(self, num_embeddings: int, embedding_dim: int, + scale: float = 10., smooth=False): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + if smooth: + weight = torch.cumsum(self.embedding.weight.data, dim=0) + # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that. + weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None] + self.embedding.weight.data[:] = weight + self.embedding.weight.data /= scale + self.scale = scale + + @property + def weight(self): + return self.embedding.weight * self.scale + + def forward(self, x): + out = self.embedding(x) * self.scale + return out + + +class HEncLayer(nn.Module): + def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, + freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, + rewrite=True): + """Encoder layer. This used both by the time and the frequency branch. + + Args: + chin: number of input channels. + chout: number of output channels. + norm_groups: number of groups for group norm. + empty: used to make a layer with just the first conv. this is used + before merging the time and freq. branches. + freq: this is acting on frequencies. + dconv: insert DConv residual branches. + norm: use GroupNorm. + context: context size for the 1x1 conv. + dconv_kw: list of kwargs for the DConv class. + pad: pad the input. Padding is done so that the output size is + always the input size / stride. + rewrite: add 1x1 conv at the end of the layer. + """ + super().__init__() + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + pad = kernel_size // 4 + else: + pad = 0 + klass = nn.Conv1d + self.freq = freq + self.kernel_size = kernel_size + self.stride = stride + self.empty = empty + self.norm = norm + self.pad = pad + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + pad = [pad, 0] + klass = nn.Conv2d + self.conv = klass(chin, chout, kernel_size, stride, pad) + if self.empty: + return + self.norm1 = norm_fn(chout) + self.rewrite = None + if rewrite: + self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context) + self.norm2 = norm_fn(2 * chout) + + self.dconv = None + if dconv: + self.dconv = DConv(chout, **dconv_kw) + + def forward(self, x, inject=None): + """ + `inject` is used to inject the result from the time branch into the frequency branch, + when both have the same stride. + """ + if not self.freq and x.dim() == 4: + B, C, Fr, T = x.shape + x = x.view(B, -1, T) + + if not self.freq: + le = x.shape[-1] + if not le % self.stride == 0: + x = F.pad(x, (0, self.stride - (le % self.stride))) + y = self.conv(x) + if self.empty: + return y + if inject is not None: + assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape) + if inject.dim() == 3 and y.dim() == 4: + inject = inject[:, :, None] + y = y + inject + y = F.gelu(self.norm1(y)) + if self.dconv: + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + if self.freq: + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + if self.rewrite: + z = self.norm2(self.rewrite(y)) + z = F.glu(z, dim=1) + else: + z = y + return z + + +class MultiWrap(nn.Module): + """ + Takes one layer and replicate it N times. each replica will act + on a frequency band. All is done so that if the N replica have the same weights, + then this is exactly equivalent to applying the original module on all frequencies. + + This is a bit over-engineered to avoid edge artifacts when splitting + the frequency bands, but it is possible the naive implementation would work as well... + """ + def __init__(self, layer, split_ratios): + """ + Args: + layer: module to clone, must be either HEncLayer or HDecLayer. + split_ratios: list of float indicating which ratio to keep for each band. + """ + super().__init__() + self.split_ratios = split_ratios + self.layers = nn.ModuleList() + self.conv = isinstance(layer, HEncLayer) + assert not layer.norm + assert layer.freq + assert layer.pad + if not self.conv: + assert not layer.context_freq + for k in range(len(split_ratios) + 1): + lay = deepcopy(layer) + if self.conv: + lay.conv.padding = (0, 0) + else: + lay.pad = False + for m in lay.modules(): + if hasattr(m, 'reset_parameters'): + m.reset_parameters() + self.layers.append(lay) + + def forward(self, x, skip=None, length=None): + B, C, Fr, T = x.shape + + ratios = list(self.split_ratios) + [1] + start = 0 + outs = [] + for ratio, layer in zip(ratios, self.layers): + if self.conv: + pad = layer.kernel_size // 4 + if ratio == 1: + limit = Fr + frames = -1 + else: + limit = int(round(Fr * ratio)) + le = limit - start + if start == 0: + le += pad + frames = round((le - layer.kernel_size) / layer.stride + 1) + limit = start + (frames - 1) * layer.stride + layer.kernel_size + if start == 0: + limit -= pad + assert limit - start > 0, (limit, start) + assert limit <= Fr, (limit, Fr) + y = x[:, :, start:limit, :] + if start == 0: + y = F.pad(y, (0, 0, pad, 0)) + if ratio == 1: + y = F.pad(y, (0, 0, 0, pad)) + outs.append(layer(y)) + start = limit - layer.kernel_size + layer.stride + else: + if ratio == 1: + limit = Fr + else: + limit = int(round(Fr * ratio)) + last = layer.last + layer.last = True + + y = x[:, :, start:limit] + s = skip[:, :, start:limit] + out, _ = layer(y, s, None) + if outs: + outs[-1][:, :, -layer.stride:] += ( + out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)) + out = out[:, :, layer.stride:] + if ratio == 1: + out = out[:, :, :-layer.stride // 2, :] + if start == 0: + out = out[:, :, layer.stride // 2:, :] + outs.append(out) + layer.last = last + start = limit + out = torch.cat(outs, dim=2) + if not self.conv and not last: + out = F.gelu(out) + if self.conv: + return out + else: + return out, None + + +class HDecLayer(nn.Module): + def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, + freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, + context_freq=True, rewrite=True): + """ + Same as HEncLayer but for decoder. See `HEncLayer` for documentation. + """ + super().__init__() + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + pad = kernel_size // 4 + else: + pad = 0 + self.pad = pad + self.last = last + self.freq = freq + self.chin = chin + self.empty = empty + self.stride = stride + self.kernel_size = kernel_size + self.norm = norm + self.context_freq = context_freq + klass = nn.Conv1d + klass_tr = nn.ConvTranspose1d + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + klass = nn.Conv2d + klass_tr = nn.ConvTranspose2d + self.conv_tr = klass_tr(chin, chout, kernel_size, stride) + self.norm2 = norm_fn(chout) + if self.empty: + return + self.rewrite = None + if rewrite: + if context_freq: + self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context) + else: + self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, + [0, context]) + self.norm1 = norm_fn(2 * chin) + + self.dconv = None + if dconv: + self.dconv = DConv(chin, **dconv_kw) + + def forward(self, x, skip, length): + if self.freq and x.dim() == 3: + B, C, T = x.shape + x = x.view(B, self.chin, -1, T) + + if not self.empty: + x = x + skip + + if self.rewrite: + y = F.glu(self.norm1(self.rewrite(x)), dim=1) + else: + y = x + if self.dconv: + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + if self.freq: + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + else: + y = x + assert skip is None + z = self.norm2(self.conv_tr(y)) + if self.freq: + if self.pad: + z = z[..., self.pad:-self.pad, :] + else: + z = z[..., self.pad:self.pad + length] + assert z.shape[-1] == length, (z.shape[-1], length) + if not self.last: + z = F.gelu(z) + return z, y + + +class HDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. + """ + @capture_init + def __init__(self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=6, + rewrite=True, + hybrid=True, + hybrid_old=False, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=2, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=4, + dconv_attn=4, + dconv_lstm=4, + dconv_init=1e-4, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=4 * 10): + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only. + hybrid_old: some models trained for MDX had a padding bug. This replicates + this bug to avoid retraining them. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + rescale: weight recaling trick + + """ + super().__init__() + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.channels = channels + self.samplerate = samplerate + self.segment = segment + + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + self.hybrid = hybrid + self.hybrid_old = hybrid_old + if hybrid_old: + assert hybrid, "hybrid_old must come with hybrid=True" + if hybrid: + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + if hybrid: + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + lstm = index >= dconv_lstm + attn = index >= dconv_attn + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + 'kernel_size': ker, + 'stride': stri, + 'freq': freq, + 'pad': pad, + 'norm': norm, + 'rewrite': rewrite, + 'norm_groups': norm_groups, + 'dconv_kw': { + 'lstm': lstm, + 'attn': attn, + 'depth': dconv_depth, + 'compress': dconv_comp, + 'init': dconv_init, + 'gelu': True, + } + } + kwt = dict(kw) + kwt['freq'] = 0 + kwt['kernel_size'] = kernel_size + kwt['stride'] = stride + kwt['pad'] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec['context_freq'] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer(chin_z, chout_z, + dconv=dconv_mode & 1, context=context_enc, **kw) + if hybrid and freq: + tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, + empty=last_freq, **kwt) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, + last=index == 0, context=context, **kw_dec) + if multi: + dec = MultiWrap(dec, multi_freqs) + if hybrid and freq: + tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, + last=index == 0, context=context, **kwt) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + if self.hybrid: + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + if not self.hybrid_old: + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect') + else: + x = pad1d(x, (pad, pad + le * hl - x.shape[-1])) + + z = spectro(x, nfft, hl)[..., :-1, :] + if self.hybrid: + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2:2+le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4 ** scale) + z = F.pad(z, (0, 0, 0, 1)) + if self.hybrid: + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + if not self.hybrid_old: + le = hl * int(math.ceil(length / hl)) + 2 * pad + else: + le = hl * int(math.ceil(length / hl)) + x = ispectro(z, hl, length=le) + if not self.hybrid_old: + x = x[..., pad:pad + length] + else: + x = x[..., :length] + else: + x = ispectro(z, hl, length) + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. + if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], mix_stft[sample, frame], niters, + residual=residual) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def forward(self, mix): + x = mix + length = x.shape[-1] + + z = self._spec(mix) + mag = self._magnitude(z).to(mix.device) + x = mag + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + if self.hybrid: + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if self.hybrid and idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + + x = torch.zeros_like(x) + if self.hybrid: + xt = torch.zeros_like(x) + # initialize everything to zero (signal will go through u-net skips). + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + if self.hybrid: + offset = self.depth - len(self.tdecoder) + if self.hybrid and idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + + # Let's make sure we used all stored skip connections. + assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + x = x.view(B, S, -1, Fq, T) + x = x * std[:, None] + mean[:, None] + + # to cpu as mps doesnt support complex numbers + # demucs issue #435 ##432 + # NOTE: in this case z already is on cpu + # TODO: remove this when mps supports complex numbers + x_is_mps = x.device.type == "mps" + if x_is_mps: + x = x.cpu() + + zout = self._mask(z, x) + x = self._ispec(zout, length) + + # back to mps device + if x_is_mps: + x = x.to('mps') + + if self.hybrid: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + return x diff --git a/AIMeiSheng/demucs/htdemucs.py b/AIMeiSheng/demucs/htdemucs.py new file mode 100644 index 0000000..5d2eaaa --- /dev/null +++ b/AIMeiSheng/demucs/htdemucs.py @@ -0,0 +1,660 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# First author is Simon Rouard. +""" +This code contains the spectrogram and Hybrid version of Demucs. +""" +import math + +from openunmix.filtering import wiener +import torch +from torch import nn +from torch.nn import functional as F +from fractions import Fraction +from einops import rearrange + +from .transformer import CrossTransformerEncoder + +from .demucs import rescale_module +from .states import capture_init +from .spec import spectro, ispectro +from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer + + +class HTDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. + """ + + @capture_init + def __init__( + self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=4, + rewrite=True, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=3, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=8, + dconv_init=1e-3, + # Before the Transformer + bottom_channels=0, + # Transformer + t_layers=5, + t_emb="sin", + t_hidden_scale=4.0, + t_heads=8, + t_dropout=0.0, + t_max_positions=10000, + t_norm_in=True, + t_norm_in_group=False, + t_group_norm=False, + t_norm_first=True, + t_norm_out=True, + t_max_period=10000.0, + t_weight_decay=0.0, + t_lr=None, + t_layer_scale=True, + t_gelu=True, + t_weight_pos_embed=1.0, + t_sin_random_shift=0, + t_cape_mean_normalize=True, + t_cape_augment=True, + t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], + t_sparse_self_attn=False, + t_sparse_cross_attn=False, + t_mask_type="diag", + t_mask_random_seed=42, + t_sparse_attn_window=500, + t_global_window=100, + t_sparsity=0.95, + t_auto_sparsity=False, + # ------ Particuliar parameters + t_cross_first=False, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=10, + use_train_segment=True, + ): + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the + transformer in order to change the number of channels + t_layers: number of layers in each branch (waveform and spec) of the transformer + t_emb: "sin", "cape" or "scaled" + t_hidden_scale: the hidden scale of the Feedforward parts of the transformer + for instance if C = 384 (the number of channels in the transformer) and + t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension + 384 * 4 = 1536 + t_heads: number of heads for the transformer + t_dropout: dropout in the transformer + t_max_positions: max_positions for the "scaled" positional embedding, only + useful if t_emb="scaled" + t_norm_in: (bool) norm before addinf positional embedding and getting into the + transformer layers + t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the + timesteps (GroupNorm with group=1) + t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the + timesteps (GroupNorm with group=1) + t_norm_first: (bool) if True the norm is before the attention and before the FFN + t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer + t_max_period: (float) denominator in the sinusoidal embedding expression + t_weight_decay: (float) weight decay for the transformer + t_lr: (float) specific learning rate for the transformer + t_layer_scale: (bool) Layer Scale for the transformer + t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else + t_weight_pos_embed: (float) weighting of the positional embedding + t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings + see: https://arxiv.org/abs/2106.03143 + t_cape_augment: (bool) if t_emb="cape", must be True during training and False + during the inference, see: https://arxiv.org/abs/2106.03143 + t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters + see: https://arxiv.org/abs/2106.03143 + t_sparse_self_attn: (bool) if True, the self attentions are sparse + t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it + unless you designed really specific masks) + t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination + with '_' between: i.e. "diag_jmask_random" (note that this is permutation + invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag") + t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed + that generated the random part of the mask + t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and + a key (j), the mask is True id |i-j|<=t_sparse_attn_window + t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :] + and mask[:, :t_global_window] will be True + t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity + level of the random part of the mask. + t_cross_first: (bool) if True cross attention is the first layer of the + transformer (False seems to be better) + rescale: weight rescaling trick + use_train_segment: (bool) if True, the actual size that is used during the + training is used during inference. + """ + super().__init__() + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.bottom_channels = bottom_channels + self.channels = channels + self.samplerate = samplerate + self.segment = segment + self.use_train_segment = use_train_segment + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + "kernel_size": ker, + "stride": stri, + "freq": freq, + "pad": pad, + "norm": norm, + "rewrite": rewrite, + "norm_groups": norm_groups, + "dconv_kw": { + "depth": dconv_depth, + "compress": dconv_comp, + "init": dconv_init, + "gelu": True, + }, + } + kwt = dict(kw) + kwt["freq"] = 0 + kwt["kernel_size"] = kernel_size + kwt["stride"] = stride + kwt["pad"] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec["context_freq"] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer( + chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw + ) + if freq: + tenc = HEncLayer( + chin, + chout, + dconv=dconv_mode & 1, + context=context_enc, + empty=last_freq, + **kwt + ) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + dec = HDecLayer( + chout_z, + chin_z, + dconv=dconv_mode & 2, + last=index == 0, + context=context, + **kw_dec + ) + if multi: + dec = MultiWrap(dec, multi_freqs) + if freq: + tdec = HDecLayer( + chout, + chin, + dconv=dconv_mode & 2, + empty=last_freq, + last=index == 0, + context=context, + **kwt + ) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale + ) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + transformer_channels = channels * growth ** (depth - 1) + if bottom_channels: + self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1) + self.channel_downsampler = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + self.channel_upsampler_t = nn.Conv1d( + transformer_channels, bottom_channels, 1 + ) + self.channel_downsampler_t = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + + transformer_channels = bottom_channels + + if t_layers > 0: + self.crosstransformer = CrossTransformerEncoder( + dim=transformer_channels, + emb=t_emb, + hidden_scale=t_hidden_scale, + num_heads=t_heads, + num_layers=t_layers, + cross_first=t_cross_first, + dropout=t_dropout, + max_positions=t_max_positions, + norm_in=t_norm_in, + norm_in_group=t_norm_in_group, + group_norm=t_group_norm, + norm_first=t_norm_first, + norm_out=t_norm_out, + max_period=t_max_period, + weight_decay=t_weight_decay, + lr=t_lr, + layer_scale=t_layer_scale, + gelu=t_gelu, + sin_random_shift=t_sin_random_shift, + weight_pos_embed=t_weight_pos_embed, + cape_mean_normalize=t_cape_mean_normalize, + cape_augment=t_cape_augment, + cape_glob_loc_scale=t_cape_glob_loc_scale, + sparse_self_attn=t_sparse_self_attn, + sparse_cross_attn=t_sparse_cross_attn, + mask_type=t_mask_type, + mask_random_seed=t_mask_random_seed, + sparse_attn_window=t_sparse_attn_window, + global_window=t_global_window, + sparsity=t_sparsity, + auto_sparsity=t_auto_sparsity, + ) + else: + self.crosstransformer = None + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") + + z = spectro(x, nfft, hl)[..., :-1, :] + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2: 2 + le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4**scale) + z = F.pad(z, (0, 0, 0, 1)) + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + le = hl * int(math.ceil(length / hl)) + 2 * pad + x = ispectro(z, hl, length=le) + x = x[..., pad: pad + length] + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. + if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], + mix_stft[sample, frame], + niters, + residual=residual, + ) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def valid_length(self, length: int): + """ + Return a length that is appropriate for evaluation. + In our case, always return the training length, unless + it is smaller than the given length, in which case this + raises an error. + """ + if not self.use_train_segment: + return length + training_length = int(self.segment * self.samplerate) + if training_length < length: + raise ValueError( + f"Given length {length} is longer than " + f"training length {training_length}") + return training_length + + def forward(self, mix): + length = mix.shape[-1] + length_pre_pad = None + if self.use_train_segment: + if self.training: + self.segment = Fraction(mix.shape[-1], self.samplerate) + else: + training_length = int(self.segment * self.samplerate) + if mix.shape[-1] < training_length: + length_pre_pad = mix.shape[-1] + mix = F.pad(mix, (0, training_length - length_pre_pad)) + z = self._spec(mix) + mag = self._magnitude(z).to(mix.device) + x = mag + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + if self.crosstransformer: + if self.bottom_channels: + b, c, f, t = x.shape + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_upsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_upsampler_t(xt) + + x, xt = self.crosstransformer(x, xt) + + if self.bottom_channels: + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_downsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_downsampler_t(xt) + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + offset = self.depth - len(self.tdecoder) + if idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + + # Let's make sure we used all stored skip connections. + assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + x = x.view(B, S, -1, Fq, T) + x = x * std[:, None] + mean[:, None] + + # to cpu as mps doesnt support complex numbers + # demucs issue #435 ##432 + # NOTE: in this case z already is on cpu + # TODO: remove this when mps supports complex numbers + x_is_mps = x.device.type == "mps" + if x_is_mps: + x = x.cpu() + + zout = self._mask(z, x) + if self.use_train_segment: + if self.training: + x = self._ispec(zout, length) + else: + x = self._ispec(zout, training_length) + else: + x = self._ispec(zout, length) + + # back to mps device + if x_is_mps: + x = x.to("mps") + + if self.use_train_segment: + if self.training: + xt = xt.view(B, S, -1, length) + else: + xt = xt.view(B, S, -1, training_length) + else: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + if length_pre_pad: + x = x[..., :length_pre_pad] + return x diff --git a/AIMeiSheng/demucs/pretrained.py b/AIMeiSheng/demucs/pretrained.py new file mode 100644 index 0000000..80ae49c --- /dev/null +++ b/AIMeiSheng/demucs/pretrained.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Loading pretrained models. +""" + +import logging +from pathlib import Path +import typing as tp + +from dora.log import fatal, bold + +from .hdemucs import HDemucs +from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa +from .states import _check_diffq + +logger = logging.getLogger(__name__) +ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/" +REMOTE_ROOT = Path(__file__).parent / 'remote' + +SOURCES = ["drums", "bass", "other", "vocals"] +DEFAULT_MODEL = 'htdemucs' + + +def demucs_unittest(): + model = HDemucs(channels=4, sources=SOURCES) + return model + + +def add_model_flags(parser): + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument("-s", "--sig", help="Locally trained XP signature.") + group.add_argument("-n", "--name", default="htdemucs", + help="Pretrained model name or signature. Default is htdemucs.") + parser.add_argument("--repo", type=Path, + help="Folder containing all pre-trained models for use with -n.") + + +def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]: + root: str = '' + models: tp.Dict[str, str] = {} + for line in remote_file_list.read_text().split('\n'): + line = line.strip() + if line.startswith('#'): + continue + elif len(line) == 0: + continue + elif line.startswith('root:'): + root = line.split(':', 1)[1].strip() + else: + sig = line.split('-', 1)[0] + assert sig not in models + models[sig] = ROOT_URL + root + line + return models + + +def get_model(name: str, + repo: tp.Optional[Path] = None): + """`name` must be a bag of models name or a pretrained signature + from the remote AWS model repo or the specified local repo if `repo` is not None. + """ + if name == 'demucs_unittest': + return demucs_unittest() + model_repo: ModelOnlyRepo + if repo is None: + models = _parse_remote_files(REMOTE_ROOT / 'files.txt') + model_repo = RemoteRepo(models) + bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) + else: + if not repo.is_dir(): + fatal(f"{repo} must exist and be a directory.") + model_repo = LocalRepo(repo) + bag_repo = BagOnlyRepo(repo, model_repo) + any_repo = AnyModelRepo(model_repo, bag_repo) + try: + model = any_repo.get_model(name) + except ImportError as exc: + if 'diffq' in exc.args[0]: + _check_diffq() + raise + + model.eval() + return model + + +def get_model_from_args(args): + """ + Load local model package or pre-trained model. + """ + if args.name is None: + args.name = DEFAULT_MODEL + print(bold("Important: the default model was recently changed to `htdemucs`"), + "the latest Hybrid Transformer Demucs model. In some cases, this model can " + "actually perform worse than previous models. To get back the old default model " + "use `-n mdx_extra_q`.") + return get_model(name=args.name, repo=args.repo) diff --git a/AIMeiSheng/demucs/py.typed b/AIMeiSheng/demucs/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/AIMeiSheng/demucs/remote/files.txt b/AIMeiSheng/demucs/remote/files.txt new file mode 100644 index 0000000..346eb33 --- /dev/null +++ b/AIMeiSheng/demucs/remote/files.txt @@ -0,0 +1,32 @@ +# MDX Models +root: mdx_final/ +0d19c1c6-0f06f20e.th +5d2d6c55-db83574e.th +7d865c68-3d5dd56b.th +7ecf8ec1-70f50cc9.th +a1d90b5c-ae9d2452.th +c511e2ab-fe698775.th +cfa93e08-61801ae1.th +e51eebcc-c1b80bdd.th +6b9c2ca1-3fd82607.th +b72baf4e-8778635e.th +42e558d4-196e0e1b.th +305bc58f-18378783.th +14fc6a69-a89dd0ee.th +464b36d7-e5a9386e.th +7fd6ef75-a905dd85.th +83fc094f-4a16d450.th +1ef250f1-592467ce.th +902315c2-b39ce9c9.th +9a6b4851-03af0aa6.th +fa0cb7f9-100d8bf4.th +# Hybrid Transformer models +root: hybrid_transformer/ +955717e8-8726e21a.th +f7e0c4bc-ba3fe64a.th +d12395a8-e57c48e6.th +92cfc3b6-ef3bcb9c.th +04573f0d-f3cf25b2.th +75fc33f5-1941ce65.th +# Experimental 6 sources model +5c90dfd2-34c22ccb.th diff --git a/AIMeiSheng/demucs/remote/hdemucs_mmi.yaml b/AIMeiSheng/demucs/remote/hdemucs_mmi.yaml new file mode 100644 index 0000000..0ea0891 --- /dev/null +++ b/AIMeiSheng/demucs/remote/hdemucs_mmi.yaml @@ -0,0 +1,2 @@ +models: ['75fc33f5'] +segment: 44 diff --git a/AIMeiSheng/demucs/remote/htdemucs.yaml b/AIMeiSheng/demucs/remote/htdemucs.yaml new file mode 100644 index 0000000..0d5f208 --- /dev/null +++ b/AIMeiSheng/demucs/remote/htdemucs.yaml @@ -0,0 +1 @@ +models: ['955717e8'] diff --git a/AIMeiSheng/demucs/remote/htdemucs_6s.yaml b/AIMeiSheng/demucs/remote/htdemucs_6s.yaml new file mode 100644 index 0000000..651a0fa --- /dev/null +++ b/AIMeiSheng/demucs/remote/htdemucs_6s.yaml @@ -0,0 +1 @@ +models: ['5c90dfd2'] diff --git a/AIMeiSheng/demucs/remote/htdemucs_ft.yaml b/AIMeiSheng/demucs/remote/htdemucs_ft.yaml new file mode 100644 index 0000000..ba5c69c --- /dev/null +++ b/AIMeiSheng/demucs/remote/htdemucs_ft.yaml @@ -0,0 +1,7 @@ +models: ['f7e0c4bc', 'd12395a8', '92cfc3b6', '04573f0d'] +weights: [ + [1., 0., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.], +] \ No newline at end of file diff --git a/AIMeiSheng/demucs/remote/mdx.yaml b/AIMeiSheng/demucs/remote/mdx.yaml new file mode 100644 index 0000000..4e81a50 --- /dev/null +++ b/AIMeiSheng/demucs/remote/mdx.yaml @@ -0,0 +1,8 @@ +models: ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68'] +weights: [ + [1., 1., 0., 0.], + [0., 1., 0., 0.], + [1., 0., 1., 1.], + [1., 0., 1., 1.], +] +segment: 44 diff --git a/AIMeiSheng/demucs/remote/mdx_extra.yaml b/AIMeiSheng/demucs/remote/mdx_extra.yaml new file mode 100644 index 0000000..847bf66 --- /dev/null +++ b/AIMeiSheng/demucs/remote/mdx_extra.yaml @@ -0,0 +1,2 @@ +models: ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08'] +segment: 44 \ No newline at end of file diff --git a/AIMeiSheng/demucs/remote/mdx_extra_q.yaml b/AIMeiSheng/demucs/remote/mdx_extra_q.yaml new file mode 100644 index 0000000..87702bc --- /dev/null +++ b/AIMeiSheng/demucs/remote/mdx_extra_q.yaml @@ -0,0 +1,2 @@ +models: ['83fc094f', '464b36d7', '14fc6a69', '7fd6ef75'] +segment: 44 diff --git a/AIMeiSheng/demucs/remote/mdx_q.yaml b/AIMeiSheng/demucs/remote/mdx_q.yaml new file mode 100644 index 0000000..827d2c6 --- /dev/null +++ b/AIMeiSheng/demucs/remote/mdx_q.yaml @@ -0,0 +1,8 @@ +models: ['6b9c2ca1', 'b72baf4e', '42e558d4', '305bc58f'] +weights: [ + [1., 1., 0., 0.], + [0., 1., 0., 0.], + [1., 0., 1., 1.], + [1., 0., 1., 1.], +] +segment: 44 diff --git a/AIMeiSheng/demucs/remote/repro_mdx_a.yaml b/AIMeiSheng/demucs/remote/repro_mdx_a.yaml new file mode 100644 index 0000000..691abc2 --- /dev/null +++ b/AIMeiSheng/demucs/remote/repro_mdx_a.yaml @@ -0,0 +1,2 @@ +models: ['9a6b4851', '1ef250f1', 'fa0cb7f9', '902315c2'] +segment: 44 diff --git a/AIMeiSheng/demucs/remote/repro_mdx_a_hybrid_only.yaml b/AIMeiSheng/demucs/remote/repro_mdx_a_hybrid_only.yaml new file mode 100644 index 0000000..78eb8e0 --- /dev/null +++ b/AIMeiSheng/demucs/remote/repro_mdx_a_hybrid_only.yaml @@ -0,0 +1,2 @@ +models: ['fa0cb7f9', '902315c2', 'fa0cb7f9', '902315c2'] +segment: 44 diff --git a/AIMeiSheng/demucs/remote/repro_mdx_a_time_only.yaml b/AIMeiSheng/demucs/remote/repro_mdx_a_time_only.yaml new file mode 100644 index 0000000..d5d16ea --- /dev/null +++ b/AIMeiSheng/demucs/remote/repro_mdx_a_time_only.yaml @@ -0,0 +1,2 @@ +models: ['9a6b4851', '9a6b4851', '1ef250f1', '1ef250f1'] +segment: 44 diff --git a/AIMeiSheng/demucs/repitch.py b/AIMeiSheng/demucs/repitch.py new file mode 100644 index 0000000..ebef736 --- /dev/null +++ b/AIMeiSheng/demucs/repitch.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Utility for on the fly pitch/tempo change for data augmentation.""" + +import random +import subprocess as sp +import tempfile + +import torch +import torchaudio as ta + +from .audio import save_audio + + +class RepitchedWrapper: + """ + Wrap a dataset to apply online change of pitch / tempo. + """ + def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12, + tempo_std=5, vocals=[3], same=True): + self.dataset = dataset + self.proba = proba + self.max_pitch = max_pitch + self.max_tempo = max_tempo + self.tempo_std = tempo_std + self.same = same + self.vocals = vocals + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + streams = self.dataset[index] + in_length = streams.shape[-1] + out_length = int((1 - 0.01 * self.max_tempo) * in_length) + + if random.random() < self.proba: + outs = [] + for idx, stream in enumerate(streams): + if idx == 0 or not self.same: + delta_pitch = random.randint(-self.max_pitch, self.max_pitch) + delta_tempo = random.gauss(0, self.tempo_std) + delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo) + stream = repitch( + stream, + delta_pitch, + delta_tempo, + voice=idx in self.vocals) + outs.append(stream[:, :out_length]) + streams = torch.stack(outs) + else: + streams = streams[..., :out_length] + return streams + + +def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100): + """ + tempo is a relative delta in percentage, so tempo=10 means tempo at 110%! + pitch is in semi tones. + Requires `soundstretch` to be installed, see + https://www.surina.net/soundtouch/soundstretch.html + """ + infile = tempfile.NamedTemporaryFile(suffix=".wav") + outfile = tempfile.NamedTemporaryFile(suffix=".wav") + save_audio(wav, infile.name, samplerate, clip='clamp') + command = [ + "soundstretch", + infile.name, + outfile.name, + f"-pitch={pitch}", + f"-tempo={tempo:.6f}", + ] + if quick: + command += ["-quick"] + if voice: + command += ["-speech"] + try: + sp.run(command, capture_output=True, check=True) + except sp.CalledProcessError as error: + raise RuntimeError(f"Could not change bpm because {error.stderr.decode('utf-8')}") + wav, sr = ta.load(outfile.name) + assert sr == samplerate + return wav diff --git a/AIMeiSheng/demucs/repo.py b/AIMeiSheng/demucs/repo.py new file mode 100644 index 0000000..75f2afa --- /dev/null +++ b/AIMeiSheng/demucs/repo.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Represents a model repository, including pre-trained models and bags of models. +A repo can either be the main remote repository stored in AWS, or a local repository +with your own models. +""" + +from hashlib import sha256 +from pathlib import Path +import typing as tp + +import torch +import yaml +import os + +from .apply import BagOfModels, Model +from .states import load_model +from AIMeiSheng.docker_demo.common import gs_demucs_model_path + +AnyModel = tp.Union[Model, BagOfModels] + + +class ModelLoadingError(RuntimeError): + pass + + +def check_checksum(path: Path, checksum: str): + sha = sha256() + with open(path, 'rb') as file: + while True: + buf = file.read(2**20) + if not buf: + break + sha.update(buf) + actual_checksum = sha.hexdigest()[:len(checksum)] + if actual_checksum != checksum: + raise ModelLoadingError(f'Invalid checksum for file {path}, ' + f'expected {checksum} but got {actual_checksum}') + + +class ModelOnlyRepo: + """Base class for all model only repos. + """ + def has_model(self, sig: str) -> bool: + raise NotImplementedError() + + def get_model(self, sig: str) -> Model: + raise NotImplementedError() + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + raise NotImplementedError() + + +class RemoteRepo(ModelOnlyRepo): + def __init__(self, models: tp.Dict[str, str]): + self._models = models + + def has_model(self, sig: str) -> bool: + return sig in self._models + + def get_model(self, sig: str) -> Model: + try: + url = self._models[sig] + except KeyError: + raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.') + #''' + path_dir = gs_demucs_model_path #"/data/bingxiao.fang/voice_conversion/svc_meisheng/AIMeiSheng/demucs_model" + url_model = url.split('/')[-1] + path_url_model = os.path.join(path_dir, url_model) + print("@@@@path_url_model:", path_url_model) + pkg = torch.load(path_url_model) + #''' + + ''' + print("@@@@url:", url) + pkg = torch.hub.load_state_dict_from_url( + url, map_location='cpu', check_hash=True) # type: ignore + + #''' + #print("@@@@@pkg:",pkg) + return load_model(pkg) + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + return self._models # type: ignore + + +class LocalRepo(ModelOnlyRepo): + def __init__(self, root: Path): + self.root = root + self.scan() + + def scan(self): + self._models = {} + self._checksums = {} + for file in self.root.iterdir(): + if file.suffix == '.th': + if '-' in file.stem: + xp_sig, checksum = file.stem.split('-') + self._checksums[xp_sig] = checksum + else: + xp_sig = file.stem + if xp_sig in self._models: + raise ModelLoadingError( + f'Duplicate pre-trained model exist for signature {xp_sig}. ' + 'Please delete all but one.') + self._models[xp_sig] = file + + def has_model(self, sig: str) -> bool: + return sig in self._models + + def get_model(self, sig: str) -> Model: + try: + file = self._models[sig] + except KeyError: + raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.') + if sig in self._checksums: + check_checksum(file, self._checksums[sig]) + return load_model(file) + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + return self._models + + +class BagOnlyRepo: + """Handles only YAML files containing bag of models, leaving the actual + model loading to some Repo. + """ + def __init__(self, root: Path, model_repo: ModelOnlyRepo): + self.root = root + self.model_repo = model_repo + self.scan() + + def scan(self): + self._bags = {} + for file in self.root.iterdir(): + if file.suffix == '.yaml': + self._bags[file.stem] = file + + def has_model(self, name: str) -> bool: + return name in self._bags + + def get_model(self, name: str) -> BagOfModels: + try: + yaml_file = self._bags[name] + except KeyError: + raise ModelLoadingError(f'{name} is neither a single pre-trained model or ' + 'a bag of models.') + bag = yaml.safe_load(open(yaml_file)) + signatures = bag['models'] + models = [self.model_repo.get_model(sig) for sig in signatures] + weights = bag.get('weights') + segment = bag.get('segment') + return BagOfModels(models, weights, segment) + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + return self._bags + + +class AnyModelRepo: + def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): + self.model_repo = model_repo + self.bag_repo = bag_repo + + def has_model(self, name_or_sig: str) -> bool: + return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig) + + def get_model(self, name_or_sig: str) -> AnyModel: + if self.model_repo.has_model(name_or_sig): + return self.model_repo.get_model(name_or_sig) + else: + return self.bag_repo.get_model(name_or_sig) + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + models = self.model_repo.list_model() + for key, value in self.bag_repo.list_model().items(): + models[key] = value + return models diff --git a/AIMeiSheng/demucs/separate.py b/AIMeiSheng/demucs/separate.py new file mode 100644 index 0000000..d5102ed --- /dev/null +++ b/AIMeiSheng/demucs/separate.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import sys +from pathlib import Path + +from dora.log import fatal +import torch as th + +from .api import Separator, save_audio, list_models + +from .apply import BagOfModels +from .htdemucs import HTDemucs +from .pretrained import add_model_flags, ModelLoadingError + + +def get_parser(): + parser = argparse.ArgumentParser("demucs.separate", + description="Separate the sources for the given tracks") + parser.add_argument("tracks", nargs='*', type=Path, default=[], help='Path to tracks') + add_model_flags(parser) + parser.add_argument("--list-models", action="store_true", help="List available models " + "from current repo and exit") + parser.add_argument("-v", "--verbose", action="store_true") + parser.add_argument("-o", + "--out", + type=Path, + default=Path("separated"), + help="Folder where to put extracted tracks. A subfolder " + "with the model name will be created.") + parser.add_argument("--filename", + default="{track}/{stem}.{ext}", + help="Set the name of output file. \n" + 'Use "{track}", "{trackext}", "{stem}", "{ext}" to use ' + "variables of track name without extension, track extension, " + "stem name and default output file extension. \n" + 'Default is "{track}/{stem}.{ext}".') + parser.add_argument("-d", + "--device", + default="cuda" if th.cuda.is_available() else "cpu", + help="Device to use, default is cuda if available else cpu") + parser.add_argument("--shifts", + default=1, + type=int, + help="Number of random shifts for equivariant stabilization." + "Increase separation time but improves quality for Demucs. 10 was used " + "in the original paper.") + parser.add_argument("--overlap", + default=0.25, + type=float, + help="Overlap between the splits.") + split_group = parser.add_mutually_exclusive_group() + split_group.add_argument("--no-split", + action="store_false", + dest="split", + default=True, + help="Doesn't split audio in chunks. " + "This can use large amounts of memory.") + split_group.add_argument("--segment", type=int, + help="Set split size of each chunk. " + "This can help save memory of graphic card. ") + parser.add_argument("--two-stems", + dest="stem", metavar="STEM", + help="Only separate audio into {STEM} and no_{STEM}. ") + parser.add_argument("--other-method", dest="other_method", choices=["none", "add", "minus"], + default="add", help='Decide how to get "no_{STEM}". "none" will not save ' + '"no_{STEM}". "add" will add all the other stems. "minus" will use the ' + "original track minus the selected stem.") + depth_group = parser.add_mutually_exclusive_group() + depth_group.add_argument("--int24", action="store_true", + help="Save wav output as 24 bits wav.") + depth_group.add_argument("--float32", action="store_true", + help="Save wav output as float32 (2x bigger).") + parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp", "none"], + help="Strategy for avoiding clipping: rescaling entire signal " + "if necessary (rescale) or hard clipping (clamp).") + format_group = parser.add_mutually_exclusive_group() + format_group.add_argument("--flac", action="store_true", + help="Convert the output wavs to flac.") + format_group.add_argument("--mp3", action="store_true", + help="Convert the output wavs to mp3.") + parser.add_argument("--mp3-bitrate", + default=320, + type=int, + help="Bitrate of converted mp3.") + parser.add_argument("--mp3-preset", choices=range(2, 8), type=int, default=2, + help="Encoder preset of MP3, 2 for highest quality, 7 for " + "fastest speed. Default is 2") + parser.add_argument("-j", "--jobs", + default=0, + type=int, + help="Number of jobs. This can increase memory usage but will " + "be much faster when multiple cores are available.") + + return parser + + +def main(opts=None): + parser = get_parser() + args = parser.parse_args(opts) + if args.list_models: + models = list_models(args.repo) + print("Bag of models:", end="\n ") + print("\n ".join(models["bag"])) + print("Single models:", end="\n ") + print("\n ".join(models["single"])) + sys.exit(0) + if len(args.tracks) == 0: + print("error: the following arguments are required: tracks", file=sys.stderr) + sys.exit(1) + + try: + separator = Separator(model=args.name, + repo=args.repo, + device=args.device, + shifts=args.shifts, + split=args.split, + overlap=args.overlap, + progress=True, + jobs=args.jobs, + segment=args.segment) + except ModelLoadingError as error: + fatal(error.args[0]) + + max_allowed_segment = float('inf') + if isinstance(separator.model, HTDemucs): + max_allowed_segment = float(separator.model.segment) + elif isinstance(separator.model, BagOfModels): + max_allowed_segment = separator.model.max_allowed_segment + if args.segment is not None and args.segment > max_allowed_segment: + fatal("Cannot use a Transformer model with a longer segment " + f"than it was trained for. Maximum segment is: {max_allowed_segment}") + + if isinstance(separator.model, BagOfModels): + print( + f"Selected model is a bag of {len(separator.model.models)} models. " + "You will see that many progress bars per track." + ) + + if args.stem is not None and args.stem not in separator.model.sources: + fatal( + 'error: stem "{stem}" is not in selected model. ' + "STEM must be one of {sources}.".format( + stem=args.stem, sources=", ".join(separator.model.sources) + ) + ) + out = args.out / args.name + out.mkdir(parents=True, exist_ok=True) + print(f"Separated tracks will be stored in {out.resolve()}") + for track in args.tracks: + if not track.exists(): + print(f"File {track} does not exist. If the path contains spaces, " + 'please try again after surrounding the entire path with quotes "".', + file=sys.stderr) + continue + print(f"Separating track {track}") + + origin, res = separator.separate_audio_file(track) + + if args.mp3: + ext = "mp3" + elif args.flac: + ext = "flac" + else: + ext = "wav" + kwargs = { + "samplerate": separator.samplerate, + "bitrate": args.mp3_bitrate, + "preset": args.mp3_preset, + "clip": args.clip_mode, + "as_float": args.float32, + "bits_per_sample": 24 if args.int24 else 16, + } + if args.stem is None: + for name, source in res.items(): + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem=name, + ext=ext, + ) + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(source, str(stem), **kwargs) + else: + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem="minus_" + args.stem, + ext=ext, + ) + if args.other_method == "minus": + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(origin - res[args.stem], str(stem), **kwargs) + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem=args.stem, + ext=ext, + ) + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(res.pop(args.stem), str(stem), **kwargs) + # Warning : after poping the stem, selected stem is no longer in the dict 'res' + if args.other_method == "add": + other_stem = th.zeros_like(next(iter(res.values()))) + for i in res.values(): + other_stem += i + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem="no_" + args.stem, + ext=ext, + ) + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(other_stem, str(stem), **kwargs) + + +if __name__ == "__main__": + main() diff --git a/AIMeiSheng/demucs/solver.py b/AIMeiSheng/demucs/solver.py new file mode 100644 index 0000000..7c80b14 --- /dev/null +++ b/AIMeiSheng/demucs/solver.py @@ -0,0 +1,405 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Main training loop.""" + +import logging + +from dora import get_xp +from dora.utils import write_and_rename +from dora.log import LogProgress, bold +import torch +import torch.nn.functional as F + +from . import augment, distrib, states, pretrained +from .apply import apply_model +from .ema import ModelEMA +from .evaluate import evaluate, new_sdr +from .svd import svd_penalty +from .utils import pull_metric, EMA + +logger = logging.getLogger(__name__) + + +def _summary(metrics): + return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items()) + + +class Solver(object): + def __init__(self, loaders, model, optimizer, args): + self.args = args + self.loaders = loaders + + self.model = model + self.optimizer = optimizer + self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer) + self.dmodel = distrib.wrap(model) + self.device = next(iter(self.model.parameters())).device + + # Exponential moving average of the model, either updated every batch or epoch. + # The best model from all the EMAs and the original one is kept based on the valid + # loss for the final best model. + self.emas = {'batch': [], 'epoch': []} + for kind in self.emas.keys(): + decays = getattr(args.ema, kind) + device = self.device if kind == 'batch' else 'cpu' + if decays: + for decay in decays: + self.emas[kind].append(ModelEMA(self.model, decay, device=device)) + + # data augment + augments = [augment.Shift(shift=int(args.dset.samplerate * args.dset.shift), + same=args.augment.shift_same)] + if args.augment.flip: + augments += [augment.FlipChannels(), augment.FlipSign()] + for aug in ['scale', 'remix']: + kw = getattr(args.augment, aug) + if kw.proba: + augments.append(getattr(augment, aug.capitalize())(**kw)) + self.augment = torch.nn.Sequential(*augments) + + xp = get_xp() + self.folder = xp.folder + # Checkpoints + self.checkpoint_file = xp.folder / 'checkpoint.th' + self.best_file = xp.folder / 'best.th' + logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve()) + self.best_state = None + self.best_changed = False + + self.link = xp.link + self.history = self.link.history + + self._reset() + + def _serialize(self, epoch): + package = {} + package['state'] = self.model.state_dict() + package['optimizer'] = self.optimizer.state_dict() + package['history'] = self.history + package['best_state'] = self.best_state + package['args'] = self.args + for kind, emas in self.emas.items(): + for k, ema in enumerate(emas): + package[f'ema_{kind}_{k}'] = ema.state_dict() + with write_and_rename(self.checkpoint_file) as tmp: + torch.save(package, tmp) + + save_every = self.args.save_every + if save_every and (epoch + 1) % save_every == 0 and epoch + 1 != self.args.epochs: + with write_and_rename(self.folder / f'checkpoint_{epoch + 1}.th') as tmp: + torch.save(package, tmp) + + if self.best_changed: + # Saving only the latest best model. + with write_and_rename(self.best_file) as tmp: + package = states.serialize_model(self.model, self.args) + package['state'] = self.best_state + torch.save(package, tmp) + self.best_changed = False + + def _reset(self): + """Reset state of the solver, potentially using checkpoint.""" + if self.checkpoint_file.exists(): + logger.info(f'Loading checkpoint model: {self.checkpoint_file}') + package = torch.load(self.checkpoint_file, 'cpu') + self.model.load_state_dict(package['state']) + self.optimizer.load_state_dict(package['optimizer']) + self.history[:] = package['history'] + self.best_state = package['best_state'] + for kind, emas in self.emas.items(): + for k, ema in enumerate(emas): + ema.load_state_dict(package[f'ema_{kind}_{k}']) + elif self.args.continue_pretrained: + model = pretrained.get_model( + name=self.args.continue_pretrained, + repo=self.args.pretrained_repo) + self.model.load_state_dict(model.state_dict()) + elif self.args.continue_from: + name = 'checkpoint.th' + root = self.folder.parent + cf = root / str(self.args.continue_from) / name + logger.info("Loading from %s", cf) + package = torch.load(cf, 'cpu') + self.best_state = package['best_state'] + if self.args.continue_best: + self.model.load_state_dict(package['best_state'], strict=False) + else: + self.model.load_state_dict(package['state'], strict=False) + if self.args.continue_opt: + self.optimizer.load_state_dict(package['optimizer']) + + def _format_train(self, metrics: dict) -> dict: + """Formatting for train/valid metrics.""" + losses = { + 'loss': format(metrics['loss'], ".4f"), + 'reco': format(metrics['reco'], ".4f"), + } + if 'nsdr' in metrics: + losses['nsdr'] = format(metrics['nsdr'], ".3f") + if self.quantizer is not None: + losses['ms'] = format(metrics['ms'], ".2f") + if 'grad' in metrics: + losses['grad'] = format(metrics['grad'], ".4f") + if 'best' in metrics: + losses['best'] = format(metrics['best'], '.4f') + if 'bname' in metrics: + losses['bname'] = metrics['bname'] + if 'penalty' in metrics: + losses['penalty'] = format(metrics['penalty'], ".4f") + if 'hloss' in metrics: + losses['hloss'] = format(metrics['hloss'], ".4f") + return losses + + def _format_test(self, metrics: dict) -> dict: + """Formatting for test metrics.""" + losses = {} + if 'sdr' in metrics: + losses['sdr'] = format(metrics['sdr'], '.3f') + if 'nsdr' in metrics: + losses['nsdr'] = format(metrics['nsdr'], '.3f') + for source in self.model.sources: + key = f'sdr_{source}' + if key in metrics: + losses[key] = format(metrics[key], '.3f') + key = f'nsdr_{source}' + if key in metrics: + losses[key] = format(metrics[key], '.3f') + return losses + + def train(self): + # Optimizing the model + if self.history: + logger.info("Replaying metrics from previous run") + for epoch, metrics in enumerate(self.history): + formatted = self._format_train(metrics['train']) + logger.info( + bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + formatted = self._format_train(metrics['valid']) + logger.info( + bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + if 'test' in metrics: + formatted = self._format_test(metrics['test']) + if formatted: + logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}")) + + epoch = 0 + for epoch in range(len(self.history), self.args.epochs): + # Train one epoch + self.model.train() # Turn on BatchNorm & Dropout + metrics = {} + logger.info('-' * 70) + logger.info("Training...") + metrics['train'] = self._run_one_epoch(epoch) + formatted = self._format_train(metrics['train']) + logger.info( + bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + + # Cross validation + logger.info('-' * 70) + logger.info('Cross validation...') + self.model.eval() # Turn off Batchnorm & Dropout + with torch.no_grad(): + valid = self._run_one_epoch(epoch, train=False) + bvalid = valid + bname = 'main' + state = states.copy_state(self.model.state_dict()) + metrics['valid'] = {} + metrics['valid']['main'] = valid + key = self.args.test.metric + for kind, emas in self.emas.items(): + for k, ema in enumerate(emas): + with ema.swap(): + valid = self._run_one_epoch(epoch, train=False) + name = f'ema_{kind}_{k}' + metrics['valid'][name] = valid + a = valid[key] + b = bvalid[key] + if key.startswith('nsdr'): + a = -a + b = -b + if a < b: + bvalid = valid + state = ema.state + bname = name + metrics['valid'].update(bvalid) + metrics['valid']['bname'] = bname + + valid_loss = metrics['valid'][key] + mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss] + if key.startswith('nsdr'): + best_loss = max(mets) + else: + best_loss = min(mets) + metrics['valid']['best'] = best_loss + if self.args.svd.penalty > 0: + kw = dict(self.args.svd) + kw.pop('penalty') + with torch.no_grad(): + penalty = svd_penalty(self.model, exact=True, **kw) + metrics['valid']['penalty'] = penalty + + formatted = self._format_train(metrics['valid']) + logger.info( + bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + + # Save the best model + if valid_loss == best_loss or self.args.dset.train_valid: + logger.info(bold('New best valid loss %.4f'), valid_loss) + self.best_state = states.copy_state(state) + self.best_changed = True + + # Eval model every `test.every` epoch or on last epoch + should_eval = (epoch + 1) % self.args.test.every == 0 + is_last = epoch == self.args.epochs - 1 + # # Tries to detect divergence in a reliable way and finish job + # # not to waste compute. + # # Commented out as this was super specific to the MDX competition. + # reco = metrics['valid']['main']['reco'] + # div = epoch >= 180 and reco > 0.18 + # div = div or epoch >= 100 and reco > 0.25 + # div = div and self.args.optim.loss == 'l1' + # if div: + # logger.warning("Finishing training early because valid loss is too high.") + # is_last = True + if should_eval or is_last: + # Evaluate on the testset + logger.info('-' * 70) + logger.info('Evaluating on the test set...') + # We switch to the best known model for testing + if self.args.test.best: + state = self.best_state + else: + state = states.copy_state(self.model.state_dict()) + compute_sdr = self.args.test.sdr and is_last + with states.swap_state(self.model, state): + with torch.no_grad(): + metrics['test'] = evaluate(self, compute_sdr=compute_sdr) + formatted = self._format_test(metrics['test']) + logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}")) + self.link.push_metrics(metrics) + + if distrib.rank == 0: + # Save model each epoch + self._serialize(epoch) + logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve()) + if is_last: + break + + def _run_one_epoch(self, epoch, train=True): + args = self.args + data_loader = self.loaders['train'] if train else self.loaders['valid'] + if distrib.world_size > 1 and train: + data_loader.sampler.set_epoch(epoch) + + label = ["Valid", "Train"][train] + name = label + f" | Epoch {epoch + 1}" + total = len(data_loader) + if args.max_batches: + total = min(total, args.max_batches) + logprog = LogProgress(logger, data_loader, total=total, + updates=self.args.misc.num_prints, name=name) + averager = EMA() + + for idx, sources in enumerate(logprog): + sources = sources.to(self.device) + if train: + sources = self.augment(sources) + mix = sources.sum(dim=1) + else: + mix = sources[:, 0] + sources = sources[:, 1:] + + if not train and self.args.valid_apply: + estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0) + else: + estimate = self.dmodel(mix) + if train and hasattr(self.model, 'transform_target'): + sources = self.model.transform_target(mix, sources) + assert estimate.shape == sources.shape, (estimate.shape, sources.shape) + dims = tuple(range(2, sources.dim())) + + if args.optim.loss == 'l1': + loss = F.l1_loss(estimate, sources, reduction='none') + loss = loss.mean(dims).mean(0) + reco = loss + elif args.optim.loss == 'mse': + loss = F.mse_loss(estimate, sources, reduction='none') + loss = loss.mean(dims) + reco = loss**0.5 + reco = reco.mean(0) + else: + raise ValueError(f"Invalid loss {self.args.loss}") + weights = torch.tensor(args.weights).to(sources) + loss = (loss * weights).sum() / weights.sum() + + ms = 0 + if self.quantizer is not None: + ms = self.quantizer.model_size() + if args.quant.diffq: + loss += args.quant.diffq * ms + + losses = {} + losses['reco'] = (reco * weights).sum() / weights.sum() + losses['ms'] = ms + + if not train: + nsdrs = new_sdr(sources, estimate.detach()).mean(0) + total = 0 + for source, nsdr, w in zip(self.model.sources, nsdrs, weights): + losses[f'nsdr_{source}'] = nsdr + total += w * nsdr + losses['nsdr'] = total / weights.sum() + + if train and args.svd.penalty > 0: + kw = dict(args.svd) + kw.pop('penalty') + penalty = svd_penalty(self.model, **kw) + losses['penalty'] = penalty + loss += args.svd.penalty * penalty + + losses['loss'] = loss + + for k, source in enumerate(self.model.sources): + losses[f'reco_{source}'] = reco[k] + + # optimize model in training mode + if train: + loss.backward() + grad_norm = 0 + grads = [] + for p in self.model.parameters(): + if p.grad is not None: + grad_norm += p.grad.data.norm()**2 + grads.append(p.grad.data) + losses['grad'] = grad_norm ** 0.5 + if args.optim.clip_grad: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + args.optim.clip_grad) + + if self.args.flag == 'uns': + for n, p in self.model.named_parameters(): + if p.grad is None: + print('no grad', n) + self.optimizer.step() + self.optimizer.zero_grad() + for ema in self.emas['batch']: + ema.update() + losses = averager(losses) + logs = self._format_train(losses) + logprog.update(**logs) + # Just in case, clear some memory + del loss, estimate, reco, ms + if args.max_batches == idx: + break + if self.args.debug and train: + break + if self.args.flag == 'debug': + break + if train: + for ema in self.emas['epoch']: + ema.update() + return distrib.average(losses, idx + 1) diff --git a/AIMeiSheng/demucs/spec.py b/AIMeiSheng/demucs/spec.py new file mode 100644 index 0000000..2925045 --- /dev/null +++ b/AIMeiSheng/demucs/spec.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Conveniance wrapper to perform STFT and iSTFT""" + +import torch as th + + +def spectro(x, n_fft=512, hop_length=None, pad=0): + *other, length = x.shape + x = x.reshape(-1, length) + is_mps = x.device.type == 'mps' + if is_mps: + x = x.cpu() + z = th.stft(x, + n_fft * (1 + pad), + hop_length or n_fft // 4, + window=th.hann_window(n_fft).to(x), + win_length=n_fft, + normalized=True, + center=True, + return_complex=True, + pad_mode='reflect') + _, freqs, frame = z.shape + return z.view(*other, freqs, frame) + + +def ispectro(z, hop_length=None, length=None, pad=0): + *other, freqs, frames = z.shape + n_fft = 2 * freqs - 2 + z = z.view(-1, freqs, frames) + win_length = n_fft // (1 + pad) + is_mps = z.device.type == 'mps' + if is_mps: + z = z.cpu() + x = th.istft(z, + n_fft, + hop_length, + window=th.hann_window(win_length).to(z.real), + win_length=win_length, + normalized=True, + length=length, + center=True) + _, length = x.shape + return x.view(*other, length) diff --git a/AIMeiSheng/demucs/states.py b/AIMeiSheng/demucs/states.py new file mode 100644 index 0000000..361bb41 --- /dev/null +++ b/AIMeiSheng/demucs/states.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Utilities to save and load models. +""" +from contextlib import contextmanager + +import functools +import hashlib +import inspect +import io +from pathlib import Path +import warnings + +from omegaconf import OmegaConf +from dora.log import fatal +import torch + + +def _check_diffq(): + try: + import diffq # noqa + except ImportError: + fatal('Trying to use DiffQ, but diffq is not installed.\n' + 'On Windows run: python.exe -m pip install diffq \n' + 'On Linux/Mac, run: python3 -m pip install diffq') + + +def get_quantizer(model, args, optimizer=None): + """Return the quantizer given the XP quantization args.""" + quantizer = None + if args.diffq: + _check_diffq() + from diffq import DiffQuantizer + quantizer = DiffQuantizer( + model, min_size=args.min_size, group_size=args.group_size) + if optimizer is not None: + quantizer.setup_optimizer(optimizer) + elif args.qat: + _check_diffq() + from diffq import UniformQuantizer + quantizer = UniformQuantizer( + model, bits=args.qat, min_size=args.min_size) + return quantizer + + +def load_model(path_or_package, strict=False): + """Load a model from the given serialized model, either given as a dict (already loaded) + or a path to a file on disk.""" + if isinstance(path_or_package, dict): + package = path_or_package + elif isinstance(path_or_package, (str, Path)): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + path = path_or_package + package = torch.load(path, 'cpu') + else: + raise ValueError(f"Invalid type for {path_or_package}.") + + klass = package["klass"] + args = package["args"] + kwargs = package["kwargs"] + + if strict: + model = klass(*args, **kwargs) + else: + sig = inspect.signature(klass) + for key in list(kwargs): + if key not in sig.parameters: + warnings.warn("Dropping inexistant parameter " + key) + del kwargs[key] + model = klass(*args, **kwargs) + + state = package["state"] + + set_state(model, state) + return model + + +def get_state(model, quantizer, half=False): + """Get the state from a model, potentially with quantization applied. + If `half` is True, model are stored as half precision, which shouldn't impact performance + but half the state size.""" + if quantizer is None: + dtype = torch.half if half else None + state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()} + else: + state = quantizer.get_quantized_state() + state['__quantized'] = True + return state + + +def set_state(model, state, quantizer=None): + """Set the state on a given model.""" + if state.get('__quantized'): + if quantizer is not None: + quantizer.restore_quantized_state(model, state['quantized']) + else: + _check_diffq() + from diffq import restore_quantized_state + restore_quantized_state(model, state) + else: + model.load_state_dict(state) + return state + + +def save_with_checksum(content, path): + """Save the given value on disk, along with a sha256 hash. + Should be used with the output of either `serialize_model` or `get_state`.""" + buf = io.BytesIO() + torch.save(content, buf) + sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] + + path = path.parent / (path.stem + "-" + sig + path.suffix) + path.write_bytes(buf.getvalue()) + + +def serialize_model(model, training_args, quantizer=None, half=True): + args, kwargs = model._init_args_kwargs + klass = model.__class__ + + state = get_state(model, quantizer, half) + return { + 'klass': klass, + 'args': args, + 'kwargs': kwargs, + 'state': state, + 'training_args': OmegaConf.to_container(training_args, resolve=True), + } + + +def copy_state(state): + return {k: v.cpu().clone() for k, v in state.items()} + + +@contextmanager +def swap_state(model, state): + """ + Context manager that swaps the state of a model, e.g: + + # model is in old state + with swap_state(model, new_state): + # model in new state + # model back to old state + """ + old_state = copy_state(model.state_dict()) + model.load_state_dict(state, strict=False) + try: + yield + finally: + model.load_state_dict(old_state) + + +def capture_init(init): + @functools.wraps(init) + def __init__(self, *args, **kwargs): + self._init_args_kwargs = (args, kwargs) + init(self, *args, **kwargs) + + return __init__ diff --git a/AIMeiSheng/demucs/svd.py b/AIMeiSheng/demucs/svd.py new file mode 100644 index 0000000..1cbaa82 --- /dev/null +++ b/AIMeiSheng/demucs/svd.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Ways to make the model stronger.""" +import random +import torch + + +def power_iteration(m, niters=1, bs=1): + """This is the power method. batch size is used to try multiple starting point in parallel.""" + assert m.dim() == 2 + assert m.shape[0] == m.shape[1] + dim = m.shape[0] + b = torch.randn(dim, bs, device=m.device, dtype=m.dtype) + + for _ in range(niters): + n = m.mm(b) + norm = n.norm(dim=0, keepdim=True) + b = n / (1e-10 + norm) + + return norm.mean() + + +# We need a shared RNG to make sure all the distributed worker will skip the penalty together, +# as otherwise we wouldn't get any speed up. +penalty_rng = random.Random(1234) + + +def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True, + proba=1, conv_only=False, exact=False, bs=1): + """ + Penalty on the largest singular value for a layer. + Args: + - model: model to penalize + - min_size: minimum size in MB of a layer to penalize. + - dim: projection dimension for the svd_lowrank. Higher is better but slower. + - niters: number of iterations in the algorithm used by svd_lowrank. + - powm: use power method instead of lowrank SVD, my own experience + is that it is both slower and less stable. + - convtr: when True, differentiate between Conv and Transposed Conv. + this is kept for compatibility with older experiments. + - proba: probability to apply the penalty. + - conv_only: only apply to conv and conv transposed, not LSTM + (might not be reliable for other models than Demucs). + - exact: use exact SVD (slow but useful at validation). + - bs: batch_size for power method. + """ + total = 0 + if penalty_rng.random() > proba: + return 0. + + for m in model.modules(): + for name, p in m.named_parameters(recurse=False): + if p.numel() / 2**18 < min_size: + continue + if convtr: + if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)): + if p.dim() in [3, 4]: + p = p.transpose(0, 1).contiguous() + if p.dim() == 3: + p = p.view(len(p), -1) + elif p.dim() == 4: + p = p.view(len(p), -1) + elif p.dim() == 1: + continue + elif conv_only: + continue + assert p.dim() == 2, (name, p.shape) + if exact: + estimate = torch.svd(p, compute_uv=False)[1].pow(2).max() + elif powm: + a, b = p.shape + if a < b: + n = p.mm(p.t()) + else: + n = p.t().mm(p) + estimate = power_iteration(n, niters, bs) + else: + estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2) + total += estimate + return total / proba diff --git a/AIMeiSheng/demucs/train.py b/AIMeiSheng/demucs/train.py new file mode 100644 index 0000000..9aa7b64 --- /dev/null +++ b/AIMeiSheng/demucs/train.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Main training script entry point""" + +import logging +import os +from pathlib import Path +import sys + +from dora import hydra_main +import hydra +from hydra.core.global_hydra import GlobalHydra +from omegaconf import OmegaConf +import torch +from torch import nn +import torchaudio +from torch.utils.data import ConcatDataset + +from . import distrib +from .wav import get_wav_datasets, get_musdb_wav_datasets +from .demucs import Demucs +from .hdemucs import HDemucs +from .htdemucs import HTDemucs +from .repitch import RepitchedWrapper +from .solver import Solver +from .states import capture_init +from .utils import random_subset + +logger = logging.getLogger(__name__) + + +class TorchHDemucsWrapper(nn.Module): + """Wrapper around torchaudio HDemucs implementation to provide the proper metadata + for model evaluation. + See https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html""" + + @capture_init + def __init__(self, **kwargs): + super().__init__() + try: + from torchaudio.models import HDemucs as TorchHDemucs + except ImportError: + raise ImportError("Please upgrade torchaudio for using its implementation of HDemucs") + self.samplerate = kwargs.pop('samplerate') + self.segment = kwargs.pop('segment') + self.sources = kwargs['sources'] + self.torch_hdemucs = TorchHDemucs(**kwargs) + + def forward(self, mix): + return self.torch_hdemucs.forward(mix) + + +def get_model(args): + extra = { + 'sources': list(args.dset.sources), + 'audio_channels': args.dset.channels, + 'samplerate': args.dset.samplerate, + 'segment': args.model_segment or 4 * args.dset.segment, + } + klass = { + 'demucs': Demucs, + 'hdemucs': HDemucs, + 'htdemucs': HTDemucs, + 'torch_hdemucs': TorchHDemucsWrapper, + }[args.model] + kw = OmegaConf.to_container(getattr(args, args.model), resolve=True) + model = klass(**extra, **kw) + return model + + +def get_optimizer(model, args): + seen_params = set() + other_params = [] + groups = [] + for n, module in model.named_modules(): + if hasattr(module, "make_optim_group"): + group = module.make_optim_group() + params = set(group["params"]) + assert params.isdisjoint(seen_params) + seen_params |= set(params) + groups.append(group) + for param in model.parameters(): + if param not in seen_params: + other_params.append(param) + groups.insert(0, {"params": other_params}) + parameters = groups + if args.optim.optim == "adam": + return torch.optim.Adam( + parameters, + lr=args.optim.lr, + betas=(args.optim.momentum, args.optim.beta2), + weight_decay=args.optim.weight_decay, + ) + elif args.optim.optim == "adamw": + return torch.optim.AdamW( + parameters, + lr=args.optim.lr, + betas=(args.optim.momentum, args.optim.beta2), + weight_decay=args.optim.weight_decay, + ) + else: + raise ValueError("Invalid optimizer %s", args.optim.optimizer) + + +def get_datasets(args): + if args.dset.backend: + torchaudio.set_audio_backend(args.dset.backend) + if args.dset.use_musdb: + train_set, valid_set = get_musdb_wav_datasets(args.dset) + else: + train_set, valid_set = [], [] + if args.dset.wav: + extra_train_set, extra_valid_set = get_wav_datasets(args.dset) + if len(args.dset.sources) <= 4: + train_set = ConcatDataset([train_set, extra_train_set]) + valid_set = ConcatDataset([valid_set, extra_valid_set]) + else: + train_set = extra_train_set + valid_set = extra_valid_set + + if args.dset.wav2: + extra_train_set, extra_valid_set = get_wav_datasets(args.dset, "wav2") + weight = args.dset.wav2_weight + if weight is not None: + b = len(train_set) + e = len(extra_train_set) + reps = max(1, round(e / b * (1 / weight - 1))) + else: + reps = 1 + train_set = ConcatDataset([train_set] * reps + [extra_train_set]) + if args.dset.wav2_valid: + if weight is not None: + b = len(valid_set) + n_kept = int(round(weight * b / (1 - weight))) + valid_set = ConcatDataset( + [valid_set, random_subset(extra_valid_set, n_kept)] + ) + else: + valid_set = ConcatDataset([valid_set, extra_valid_set]) + if args.dset.valid_samples is not None: + valid_set = random_subset(valid_set, args.dset.valid_samples) + assert len(train_set) + assert len(valid_set) + return train_set, valid_set + + +def get_solver(args, model_only=False): + distrib.init() + + torch.manual_seed(args.seed) + model = get_model(args) + if args.misc.show: + logger.info(model) + mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20 + logger.info('Size: %.1f MB', mb) + if hasattr(model, 'valid_length'): + field = model.valid_length(1) + logger.info('Field: %.1f ms', field / args.dset.samplerate * 1000) + sys.exit(0) + + # torch also initialize cuda seed if available + if torch.cuda.is_available(): + model.cuda() + + # optimizer + optimizer = get_optimizer(model, args) + + assert args.batch_size % distrib.world_size == 0 + args.batch_size //= distrib.world_size + + if model_only: + return Solver(None, model, optimizer, args) + + train_set, valid_set = get_datasets(args) + + if args.augment.repitch.proba: + vocals = [] + if 'vocals' in args.dset.sources: + vocals.append(args.dset.sources.index('vocals')) + else: + logger.warning('No vocal source found') + if args.augment.repitch.proba: + train_set = RepitchedWrapper(train_set, vocals=vocals, **args.augment.repitch) + + logger.info("train/valid set size: %d %d", len(train_set), len(valid_set)) + train_loader = distrib.loader( + train_set, batch_size=args.batch_size, shuffle=True, + num_workers=args.misc.num_workers, drop_last=True) + if args.dset.full_cv: + valid_loader = distrib.loader( + valid_set, batch_size=1, shuffle=False, + num_workers=args.misc.num_workers) + else: + valid_loader = distrib.loader( + valid_set, batch_size=args.batch_size, shuffle=False, + num_workers=args.misc.num_workers, drop_last=True) + loaders = {"train": train_loader, "valid": valid_loader} + + # Construct Solver + return Solver(loaders, model, optimizer, args) + + +def get_solver_from_sig(sig, model_only=False): + inst = GlobalHydra.instance() + hyd = None + if inst.is_initialized(): + hyd = inst.hydra + inst.clear() + xp = main.get_xp_from_sig(sig) + if hyd is not None: + inst.clear() + inst.initialize(hyd) + + with xp.enter(stack=True): + return get_solver(xp.cfg, model_only) + + +@hydra_main(config_path="../conf", config_name="config", version_base="1.1") +def main(args): + global __file__ + __file__ = hydra.utils.to_absolute_path(__file__) + for attr in ["musdb", "wav", "metadata"]: + val = getattr(args.dset, attr) + if val is not None: + setattr(args.dset, attr, hydra.utils.to_absolute_path(val)) + + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["MKL_NUM_THREADS"] = "1" + + if args.misc.verbose: + logger.setLevel(logging.DEBUG) + + logger.info("For logs, checkpoints and samples check %s", os.getcwd()) + logger.debug(args) + from dora import get_xp + logger.debug(get_xp().cfg) + + solver = get_solver(args) + solver.train() + + +if '_DORA_TEST_PATH' in os.environ: + main.dora.dir = Path(os.environ['_DORA_TEST_PATH']) + + +if __name__ == "__main__": + main() diff --git a/AIMeiSheng/demucs/transformer.py b/AIMeiSheng/demucs/transformer.py new file mode 100644 index 0000000..56a465b --- /dev/null +++ b/AIMeiSheng/demucs/transformer.py @@ -0,0 +1,839 @@ +# Copyright (c) 2019-present, Meta, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# First author is Simon Rouard. + +import random +import typing as tp + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +from einops import rearrange + + +def create_sin_embedding( + length: int, dim: int, shift: int = 0, device="cpu", max_period=10000 +): + # We aim for TBC format + assert dim % 2 == 0 + pos = shift + torch.arange(length, device=device).view(-1, 1, 1) + half_dim = dim // 2 + adim = torch.arange(dim // 2, device=device).view(1, 1, -1) + phase = pos / (max_period ** (adim / (half_dim - 1))) + return torch.cat( + [ + torch.cos(phase), + torch.sin(phase), + ], + dim=-1, + ) + + +def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000): + """ + :param d_model: dimension of the model + :param height: height of the positions + :param width: width of the positions + :return: d_model*height*width position matrix + """ + if d_model % 4 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " + "odd dimension (got dim={:d})".format(d_model) + ) + pe = torch.zeros(d_model, height, width) + # Each dimension use half of d_model + d_model = int(d_model / 2) + div_term = torch.exp( + torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model) + ) + pos_w = torch.arange(0.0, width).unsqueeze(1) + pos_h = torch.arange(0.0, height).unsqueeze(1) + pe[0:d_model:2, :, :] = ( + torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pe[1:d_model:2, :, :] = ( + torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pe[d_model::2, :, :] = ( + torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + pe[d_model + 1:: 2, :, :] = ( + torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + + return pe[None, :].to(device) + + +def create_sin_embedding_cape( + length: int, + dim: int, + batch_size: int, + mean_normalize: bool, + augment: bool, # True during training + max_global_shift: float = 0.0, # delta max + max_local_shift: float = 0.0, # epsilon max + max_scale: float = 1.0, + device: str = "cpu", + max_period: float = 10000.0, +): + # We aim for TBC format + assert dim % 2 == 0 + pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1) + pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1) + if mean_normalize: + pos -= torch.nanmean(pos, dim=0, keepdim=True) + + if augment: + delta = np.random.uniform( + -max_global_shift, +max_global_shift, size=[1, batch_size, 1] + ) + delta_local = np.random.uniform( + -max_local_shift, +max_local_shift, size=[length, batch_size, 1] + ) + log_lambdas = np.random.uniform( + -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1] + ) + pos = (pos + delta + delta_local) * np.exp(log_lambdas) + + pos = pos.to(device) + + half_dim = dim // 2 + adim = torch.arange(dim // 2, device=device).view(1, 1, -1) + phase = pos / (max_period ** (adim / (half_dim - 1))) + return torch.cat( + [ + torch.cos(phase), + torch.sin(phase), + ], + dim=-1, + ).float() + + +def get_causal_mask(length): + pos = torch.arange(length) + return pos > pos[:, None] + + +def get_elementary_mask( + T1, + T2, + mask_type, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, +): + """ + When the input of the Decoder has length T1 and the output T2 + The mask matrix has shape (T2, T1) + """ + assert mask_type in ["diag", "jmask", "random", "global"] + + if mask_type == "global": + mask = torch.zeros(T2, T1, dtype=torch.bool) + mask[:, :global_window] = True + line_window = int(global_window * T2 / T1) + mask[:line_window, :] = True + + if mask_type == "diag": + + mask = torch.zeros(T2, T1, dtype=torch.bool) + rows = torch.arange(T2)[:, None] + cols = ( + (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1)) + .long() + .clamp(0, T1 - 1) + ) + mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols)) + + elif mask_type == "jmask": + mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool) + rows = torch.arange(T2 + 2)[:, None] + t = torch.arange(0, int((2 * T1) ** 0.5 + 1)) + t = (t * (t + 1) / 2).int() + t = torch.cat([-t.flip(0)[:-1], t]) + cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1) + mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols)) + mask = mask[1:-1, 1:-1] + + elif mask_type == "random": + gene = torch.Generator(device=device) + gene.manual_seed(mask_random_seed) + mask = ( + torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1) + > sparsity + ) + + mask = mask.to(device) + return mask + + +def get_mask( + T1, + T2, + mask_type, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, +): + """ + Return a SparseCSRTensor mask that is a combination of elementary masks + mask_type can be a combination of multiple masks: for instance "diag_jmask_random" + """ + from xformers.sparse import SparseCSRTensor + # create a list + mask_types = mask_type.split("_") + + all_masks = [ + get_elementary_mask( + T1, + T2, + mask, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, + ) + for mask in mask_types + ] + + final_mask = torch.stack(all_masks).sum(axis=0) > 0 + + return SparseCSRTensor.from_dense(final_mask[None]) + + +class ScaledEmbedding(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + scale: float = 1.0, + boost: float = 3.0, + ): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + self.embedding.weight.data *= scale / boost + self.boost = boost + + @property + def weight(self): + return self.embedding.weight * self.boost + + def forward(self, x): + return self.embedding(x) * self.boost + + +class LayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonaly residual outputs close to 0 initially, then learnt. + """ + + def __init__(self, channels: int, init: float = 0, channel_last=False): + """ + channel_last = False corresponds to (B, C, T) tensors + channel_last = True corresponds to (T, B, C) tensors + """ + super().__init__() + self.channel_last = channel_last + self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) + self.scale.data[:] = init + + def forward(self, x): + if self.channel_last: + return self.scale * x + else: + return self.scale[:, None] * x + + +class MyGroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + """ + x: (B, T, C) + if num_groups=1: Normalisation on all T and C together for each B + """ + x = x.transpose(1, 2) + return super().forward(x).transpose(1, 2) + + +class MyTransformerEncoderLayer(nn.TransformerEncoderLayer): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation=F.relu, + group_norm=0, + norm_first=False, + norm_out=False, + layer_norm_eps=1e-5, + layer_scale=False, + init_values=1e-4, + device=None, + dtype=None, + sparse=False, + mask_type="diag", + mask_random_seed=42, + sparse_attn_window=500, + global_window=50, + auto_sparsity=False, + sparsity=0.95, + batch_first=False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + batch_first=batch_first, + norm_first=norm_first, + device=device, + dtype=dtype, + ) + self.sparse = sparse + self.auto_sparsity = auto_sparsity + if sparse: + if not auto_sparsity: + self.mask_type = mask_type + self.sparse_attn_window = sparse_attn_window + self.global_window = global_window + self.sparsity = sparsity + if group_norm: + self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + + self.norm_out = None + if self.norm_first & norm_out: + self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model) + self.gamma_1 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + self.gamma_2 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + + if sparse: + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, + auto_sparsity=sparsity if auto_sparsity else 0, + ) + self.__setattr__("src_mask", torch.zeros(1, 1)) + self.mask_random_seed = mask_random_seed + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + """ + if batch_first = False, src shape is (T, B, C) + the case where batch_first=True is not covered + """ + device = src.device + x = src + T, B, C = x.shape + if self.sparse and not self.auto_sparsity: + assert src_mask is None + src_mask = self.src_mask + if src_mask.shape[-1] != T: + src_mask = get_mask( + T, + T, + self.mask_type, + self.sparse_attn_window, + self.global_window, + self.mask_random_seed, + self.sparsity, + device, + ) + self.__setattr__("src_mask", src_mask) + + if self.norm_first: + x = x + self.gamma_1( + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + ) + x = x + self.gamma_2(self._ff_block(self.norm2(x))) + + if self.norm_out: + x = self.norm_out(x) + else: + x = self.norm1( + x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)) + ) + x = self.norm2(x + self.gamma_2(self._ff_block(x))) + + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation=F.relu, + layer_norm_eps: float = 1e-5, + layer_scale: bool = False, + init_values: float = 1e-4, + norm_first: bool = False, + group_norm: bool = False, + norm_out: bool = False, + sparse=False, + mask_type="diag", + mask_random_seed=42, + sparse_attn_window=500, + global_window=50, + sparsity=0.95, + auto_sparsity=None, + device=None, + dtype=None, + batch_first=False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.sparse = sparse + self.auto_sparsity = auto_sparsity + if sparse: + if not auto_sparsity: + self.mask_type = mask_type + self.sparse_attn_window = sparse_attn_window + self.global_window = global_window + self.sparsity = sparsity + + self.cross_attn: nn.Module + self.cross_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) + + self.norm_first = norm_first + self.norm1: nn.Module + self.norm2: nn.Module + self.norm3: nn.Module + if group_norm: + self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + else: + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + + self.norm_out = None + if self.norm_first & norm_out: + self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model) + + self.gamma_1 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + self.gamma_2 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + self.activation = self._get_activation_fn(activation) + else: + self.activation = activation + + if sparse: + self.cross_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, + auto_sparsity=sparsity if auto_sparsity else 0) + if not auto_sparsity: + self.__setattr__("mask", torch.zeros(1, 1)) + self.mask_random_seed = mask_random_seed + + def forward(self, q, k, mask=None): + """ + Args: + q: tensor of shape (T, B, C) + k: tensor of shape (S, B, C) + mask: tensor of shape (T, S) + + """ + device = q.device + T, B, C = q.shape + S, B, C = k.shape + if self.sparse and not self.auto_sparsity: + assert mask is None + mask = self.mask + if mask.shape[-1] != S or mask.shape[-2] != T: + mask = get_mask( + S, + T, + self.mask_type, + self.sparse_attn_window, + self.global_window, + self.mask_random_seed, + self.sparsity, + device, + ) + self.__setattr__("mask", mask) + + if self.norm_first: + x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask)) + x = x + self.gamma_2(self._ff_block(self.norm3(x))) + if self.norm_out: + x = self.norm_out(x) + else: + x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask))) + x = self.norm2(x + self.gamma_2(self._ff_block(x))) + + return x + + # self-attention block + def _ca_block(self, q, k, attn_mask=None): + x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x): + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + def _get_activation_fn(self, activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +# ----------------- MULTI-BLOCKS MODELS: ----------------------- + + +class CrossTransformerEncoder(nn.Module): + def __init__( + self, + dim: int, + emb: str = "sin", + hidden_scale: float = 4.0, + num_heads: int = 8, + num_layers: int = 6, + cross_first: bool = False, + dropout: float = 0.0, + max_positions: int = 1000, + norm_in: bool = True, + norm_in_group: bool = False, + group_norm: int = False, + norm_first: bool = False, + norm_out: bool = False, + max_period: float = 10000.0, + weight_decay: float = 0.0, + lr: tp.Optional[float] = None, + layer_scale: bool = False, + gelu: bool = True, + sin_random_shift: int = 0, + weight_pos_embed: float = 1.0, + cape_mean_normalize: bool = True, + cape_augment: bool = True, + cape_glob_loc_scale: list = [5000.0, 1.0, 1.4], + sparse_self_attn: bool = False, + sparse_cross_attn: bool = False, + mask_type: str = "diag", + mask_random_seed: int = 42, + sparse_attn_window: int = 500, + global_window: int = 50, + auto_sparsity: bool = False, + sparsity: float = 0.95, + ): + super().__init__() + """ + """ + assert dim % num_heads == 0 + + hidden_dim = int(dim * hidden_scale) + + self.num_layers = num_layers + # classic parity = 1 means that if idx%2 == 1 there is a + # classical encoder else there is a cross encoder + self.classic_parity = 1 if cross_first else 0 + self.emb = emb + self.max_period = max_period + self.weight_decay = weight_decay + self.weight_pos_embed = weight_pos_embed + self.sin_random_shift = sin_random_shift + if emb == "cape": + self.cape_mean_normalize = cape_mean_normalize + self.cape_augment = cape_augment + self.cape_glob_loc_scale = cape_glob_loc_scale + if emb == "scaled": + self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2) + + self.lr = lr + + activation: tp.Any = F.gelu if gelu else F.relu + + self.norm_in: nn.Module + self.norm_in_t: nn.Module + if norm_in: + self.norm_in = nn.LayerNorm(dim) + self.norm_in_t = nn.LayerNorm(dim) + elif norm_in_group: + self.norm_in = MyGroupNorm(int(norm_in_group), dim) + self.norm_in_t = MyGroupNorm(int(norm_in_group), dim) + else: + self.norm_in = nn.Identity() + self.norm_in_t = nn.Identity() + + # spectrogram layers + self.layers = nn.ModuleList() + # temporal layers + self.layers_t = nn.ModuleList() + + kwargs_common = { + "d_model": dim, + "nhead": num_heads, + "dim_feedforward": hidden_dim, + "dropout": dropout, + "activation": activation, + "group_norm": group_norm, + "norm_first": norm_first, + "norm_out": norm_out, + "layer_scale": layer_scale, + "mask_type": mask_type, + "mask_random_seed": mask_random_seed, + "sparse_attn_window": sparse_attn_window, + "global_window": global_window, + "sparsity": sparsity, + "auto_sparsity": auto_sparsity, + "batch_first": True, + } + + kwargs_classic_encoder = dict(kwargs_common) + kwargs_classic_encoder.update({ + "sparse": sparse_self_attn, + }) + kwargs_cross_encoder = dict(kwargs_common) + kwargs_cross_encoder.update({ + "sparse": sparse_cross_attn, + }) + + for idx in range(num_layers): + if idx % 2 == self.classic_parity: + + self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder)) + self.layers_t.append( + MyTransformerEncoderLayer(**kwargs_classic_encoder) + ) + + else: + self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder)) + + self.layers_t.append( + CrossTransformerEncoderLayer(**kwargs_cross_encoder) + ) + + def forward(self, x, xt): + B, C, Fr, T1 = x.shape + pos_emb_2d = create_2d_sin_embedding( + C, Fr, T1, x.device, self.max_period + ) # (1, C, Fr, T1) + pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c") + x = rearrange(x, "b c fr t1 -> b (t1 fr) c") + x = self.norm_in(x) + x = x + self.weight_pos_embed * pos_emb_2d + + B, C, T2 = xt.shape + xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C + pos_emb = self._get_pos_embedding(T2, B, C, x.device) + pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c") + xt = self.norm_in_t(xt) + xt = xt + self.weight_pos_embed * pos_emb + + for idx in range(self.num_layers): + if idx % 2 == self.classic_parity: + x = self.layers[idx](x) + xt = self.layers_t[idx](xt) + else: + old_x = x + x = self.layers[idx](x, xt) + xt = self.layers_t[idx](xt, old_x) + + x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1) + xt = rearrange(xt, "b t2 c -> b c t2") + return x, xt + + def _get_pos_embedding(self, T, B, C, device): + if self.emb == "sin": + shift = random.randrange(self.sin_random_shift + 1) + pos_emb = create_sin_embedding( + T, C, shift=shift, device=device, max_period=self.max_period + ) + elif self.emb == "cape": + if self.training: + pos_emb = create_sin_embedding_cape( + T, + C, + B, + device=device, + max_period=self.max_period, + mean_normalize=self.cape_mean_normalize, + augment=self.cape_augment, + max_global_shift=self.cape_glob_loc_scale[0], + max_local_shift=self.cape_glob_loc_scale[1], + max_scale=self.cape_glob_loc_scale[2], + ) + else: + pos_emb = create_sin_embedding_cape( + T, + C, + B, + device=device, + max_period=self.max_period, + mean_normalize=self.cape_mean_normalize, + augment=False, + ) + + elif self.emb == "scaled": + pos = torch.arange(T, device=device) + pos_emb = self.position_embeddings(pos)[:, None] + + return pos_emb + + def make_optim_group(self): + group = {"params": list(self.parameters()), "weight_decay": self.weight_decay} + if self.lr is not None: + group["lr"] = self.lr + return group + + +# Attention Modules + + +class MultiheadAttention(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + auto_sparsity=None, + ): + super().__init__() + assert auto_sparsity is not None, "sanity check" + self.num_heads = num_heads + self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.attn_drop = torch.nn.Dropout(dropout) + self.proj = torch.nn.Linear(embed_dim, embed_dim, bias) + self.proj_drop = torch.nn.Dropout(dropout) + self.batch_first = batch_first + self.auto_sparsity = auto_sparsity + + def forward( + self, + query, + key, + value, + key_padding_mask=None, + need_weights=True, + attn_mask=None, + average_attn_weights=True, + ): + + if not self.batch_first: # N, B, C + query = query.permute(1, 0, 2) # B, N_q, C + key = key.permute(1, 0, 2) # B, N_k, C + value = value.permute(1, 0, 2) # B, N_k, C + B, N_q, C = query.shape + B, N_k, C = key.shape + + q = ( + self.q(query) + .reshape(B, N_q, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + q = q.flatten(0, 1) + k = ( + self.k(key) + .reshape(B, N_k, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = k.flatten(0, 1) + v = ( + self.v(value) + .reshape(B, N_k, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + v = v.flatten(0, 1) + + if self.auto_sparsity: + assert attn_mask is None + x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity) + else: + x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop) + x = x.reshape(B, self.num_heads, N_q, C // self.num_heads) + + x = x.transpose(1, 2).reshape(B, N_q, C) + x = self.proj(x) + x = self.proj_drop(x) + if not self.batch_first: + x = x.permute(1, 0, 2) + return x, None + + +def scaled_query_key_softmax(q, k, att_mask): + from xformers.ops import masked_matmul + q = q / (k.size(-1)) ** 0.5 + att = masked_matmul(q, k.transpose(-2, -1), att_mask) + att = torch.nn.functional.softmax(att, -1) + return att + + +def scaled_dot_product_attention(q, k, v, att_mask, dropout): + att = scaled_query_key_softmax(q, k, att_mask=att_mask) + att = dropout(att) + y = att @ v + return y + + +def _compute_buckets(x, R): + qq = torch.einsum('btf,bfhi->bhti', x, R) + qq = torch.cat([qq, -qq], dim=-1) + buckets = qq.argmax(dim=-1) + + return buckets.permute(0, 2, 1).byte().contiguous() + + +def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None): + # assert False, "The code for the custom sparse kernel is not ready for release yet." + from xformers.ops import find_locations, sparse_memory_efficient_attention + n_hashes = 32 + proj_size = 4 + query, key, value = [x.contiguous() for x in [query, key, value]] + with torch.no_grad(): + R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device) + bucket_query = _compute_buckets(query, R) + bucket_key = _compute_buckets(key, R) + row_offsets, column_indices = find_locations( + bucket_query, bucket_key, sparsity, infer_sparsity) + return sparse_memory_efficient_attention( + query, key, value, row_offsets, column_indices, attn_bias) diff --git a/AIMeiSheng/demucs/utils.py b/AIMeiSheng/demucs/utils.py new file mode 100644 index 0000000..a3f5993 --- /dev/null +++ b/AIMeiSheng/demucs/utils.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from concurrent.futures import CancelledError +from contextlib import contextmanager +import math +import os +import tempfile +import typing as tp + +import torch +from torch.nn import functional as F +from torch.utils.data import Subset + + +def unfold(a, kernel_size, stride): + """Given input of size [*OT, T], output Tensor of size [*OT, F, K] + with K the kernel size, by extracting frames with the given stride. + + This will pad the input so that `F = ceil(T / K)`. + + see https://github.com/pytorch/pytorch/issues/60466 + """ + *shape, length = a.shape + n_frames = math.ceil(length / stride) + tgt_length = (n_frames - 1) * stride + kernel_size + a = F.pad(a, (0, tgt_length - length)) + strides = list(a.stride()) + assert strides[-1] == 1, 'data should be contiguous' + strides = strides[:-1] + [stride, 1] + return a.as_strided([*shape, n_frames, kernel_size], strides) + + +def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]): + """ + Center trim `tensor` with respect to `reference`, along the last dimension. + `reference` can also be a number, representing the length to trim to. + If the size difference != 0 mod 2, the extra sample is removed on the right side. + """ + ref_size: int + if isinstance(reference, torch.Tensor): + ref_size = reference.size(-1) + else: + ref_size = reference + delta = tensor.size(-1) - ref_size + if delta < 0: + raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.") + if delta: + tensor = tensor[..., delta // 2:-(delta - delta // 2)] + return tensor + + +def pull_metric(history: tp.List[dict], name: str): + out = [] + for metrics in history: + metric = metrics + for part in name.split("."): + metric = metric[part] + out.append(metric) + return out + + +def EMA(beta: float = 1): + """ + Exponential Moving Average callback. + Returns a single function that can be called to repeatidly update the EMA + with a dict of metrics. The callback will return + the new averaged dict of metrics. + + Note that for `beta=1`, this is just plain averaging. + """ + fix: tp.Dict[str, float] = defaultdict(float) + total: tp.Dict[str, float] = defaultdict(float) + + def _update(metrics: dict, weight: float = 1) -> dict: + nonlocal total, fix + for key, value in metrics.items(): + total[key] = total[key] * beta + weight * float(value) + fix[key] = fix[key] * beta + weight + return {key: tot / fix[key] for key, tot in total.items()} + return _update + + +def sizeof_fmt(num: float, suffix: str = 'B'): + """ + Given `num` bytes, return human readable size. + Taken from https://stackoverflow.com/a/1094933 + """ + for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: + if abs(num) < 1024.0: + return "%3.1f%s%s" % (num, unit, suffix) + num /= 1024.0 + return "%.1f%s%s" % (num, 'Yi', suffix) + + +@contextmanager +def temp_filenames(count: int, delete=True): + names = [] + try: + for _ in range(count): + names.append(tempfile.NamedTemporaryFile(delete=False).name) + yield names + finally: + if delete: + for name in names: + os.unlink(name) + + +def random_subset(dataset, max_samples: int, seed: int = 42): + if max_samples >= len(dataset): + return dataset + + generator = torch.Generator().manual_seed(seed) + perm = torch.randperm(len(dataset), generator=generator) + return Subset(dataset, perm[:max_samples].tolist()) + + +class DummyPoolExecutor: + class DummyResult: + def __init__(self, func, _dict, *args, **kwargs): + self.func = func + self._dict = _dict + self.args = args + self.kwargs = kwargs + + def result(self): + if self._dict["run"]: + return self.func(*self.args, **self.kwargs) + else: + raise CancelledError() + + def __init__(self, workers=0): + self._dict = {"run": True} + + def submit(self, func, *args, **kwargs): + return DummyPoolExecutor.DummyResult(func, self._dict, *args, **kwargs) + + def shutdown(self, *_, **__): + self._dict["run"] = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + return diff --git a/AIMeiSheng/demucs/wav.py b/AIMeiSheng/demucs/wav.py new file mode 100644 index 0000000..6acb9b5 --- /dev/null +++ b/AIMeiSheng/demucs/wav.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Loading wav based datasets, including MusdbHQ.""" + +from collections import OrderedDict +import hashlib +import math +import json +import os +from pathlib import Path +import tqdm + +import musdb +import julius +import torch as th +from torch import distributed +import torchaudio as ta +from torch.nn import functional as F + +from .audio import convert_audio_channels +from . import distrib + +MIXTURE = "mixture" +EXT = ".wav" + + +def _track_metadata(track, sources, normalize=True, ext=EXT): + track_length = None + track_samplerate = None + mean = 0 + std = 1 + for source in sources + [MIXTURE]: + file = track / f"{source}{ext}" + if source == MIXTURE and not file.exists(): + audio = 0 + for sub_source in sources: + sub_file = track / f"{sub_source}{ext}" + sub_audio, sr = ta.load(sub_file) + audio += sub_audio + would_clip = audio.abs().max() >= 1 + if would_clip: + assert ta.get_audio_backend() == 'soundfile', 'use dset.backend=soundfile' + ta.save(file, audio, sr, encoding='PCM_F') + + try: + info = ta.info(str(file)) + except RuntimeError: + print(file) + raise + length = info.num_frames + if track_length is None: + track_length = length + track_samplerate = info.sample_rate + elif track_length != length: + raise ValueError( + f"Invalid length for file {file}: " + f"expecting {track_length} but got {length}.") + elif info.sample_rate != track_samplerate: + raise ValueError( + f"Invalid sample rate for file {file}: " + f"expecting {track_samplerate} but got {info.sample_rate}.") + if source == MIXTURE and normalize: + try: + wav, _ = ta.load(str(file)) + except RuntimeError: + print(file) + raise + wav = wav.mean(0) + mean = wav.mean().item() + std = wav.std().item() + + return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate} + + +def build_metadata(path, sources, normalize=True, ext=EXT): + """ + Build the metadata for `Wavset`. + + Args: + path (str or Path): path to dataset. + sources (list[str]): list of sources to look for. + normalize (bool): if True, loads full track and store normalization + values based on the mixture file. + ext (str): extension of audio files (default is .wav). + """ + + meta = {} + path = Path(path) + pendings = [] + from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(8) as pool: + for root, folders, files in os.walk(path, followlinks=True): + root = Path(root) + if root.name.startswith('.') or folders or root == path: + continue + name = str(root.relative_to(path)) + pendings.append((name, pool.submit(_track_metadata, root, sources, normalize, ext))) + # meta[name] = _track_metadata(root, sources, normalize, ext) + for name, pending in tqdm.tqdm(pendings, ncols=120): + meta[name] = pending.result() + return meta + + +class Wavset: + def __init__( + self, + root, metadata, sources, + segment=None, shift=None, normalize=True, + samplerate=44100, channels=2, ext=EXT): + """ + Waveset (or mp3 set for that matter). Can be used to train + with arbitrary sources. Each track should be one folder inside of `path`. + The folder should contain files named `{source}.{ext}`. + + Args: + root (Path or str): root folder for the dataset. + metadata (dict): output from `build_metadata`. + sources (list[str]): list of source names. + segment (None or float): segment length in seconds. If `None`, returns entire tracks. + shift (None or float): stride in seconds bewteen samples. + normalize (bool): normalizes input audio, **based on the metadata content**, + i.e. the entire track is normalized, not individual extracts. + samplerate (int): target sample rate. if the file sample rate + is different, it will be resampled on the fly. + channels (int): target nb of channels. if different, will be + changed onthe fly. + ext (str): extension for audio files (default is .wav). + + samplerate and channels are converted on the fly. + """ + self.root = Path(root) + self.metadata = OrderedDict(metadata) + self.segment = segment + self.shift = shift or segment + self.normalize = normalize + self.sources = sources + self.channels = channels + self.samplerate = samplerate + self.ext = ext + self.num_examples = [] + for name, meta in self.metadata.items(): + track_duration = meta['length'] / meta['samplerate'] + if segment is None or track_duration < segment: + examples = 1 + else: + examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1) + self.num_examples.append(examples) + + def __len__(self): + return sum(self.num_examples) + + def get_file(self, name, source): + return self.root / name / f"{source}{self.ext}" + + def __getitem__(self, index): + for name, examples in zip(self.metadata, self.num_examples): + if index >= examples: + index -= examples + continue + meta = self.metadata[name] + num_frames = -1 + offset = 0 + if self.segment is not None: + offset = int(meta['samplerate'] * self.shift * index) + num_frames = int(math.ceil(meta['samplerate'] * self.segment)) + wavs = [] + for source in self.sources: + file = self.get_file(name, source) + wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames) + wav = convert_audio_channels(wav, self.channels) + wavs.append(wav) + + example = th.stack(wavs) + example = julius.resample_frac(example, meta['samplerate'], self.samplerate) + if self.normalize: + example = (example - meta['mean']) / meta['std'] + if self.segment: + length = int(self.segment * self.samplerate) + example = example[..., :length] + example = F.pad(example, (0, length - example.shape[-1])) + return example + + +def get_wav_datasets(args, name='wav'): + """Extract the wav datasets from the XP arguments.""" + path = getattr(args, name) + sig = hashlib.sha1(str(path).encode()).hexdigest()[:8] + metadata_file = Path(args.metadata) / ('wav_' + sig + ".json") + train_path = Path(path) / "train" + valid_path = Path(path) / "valid" + if not metadata_file.is_file() and distrib.rank == 0: + metadata_file.parent.mkdir(exist_ok=True, parents=True) + train = build_metadata(train_path, args.sources) + valid = build_metadata(valid_path, args.sources) + json.dump([train, valid], open(metadata_file, "w")) + if distrib.world_size > 1: + distributed.barrier() + train, valid = json.load(open(metadata_file)) + if args.full_cv: + kw_cv = {} + else: + kw_cv = {'segment': args.segment, 'shift': args.shift} + train_set = Wavset(train_path, train, args.sources, + segment=args.segment, shift=args.shift, + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize) + valid_set = Wavset(valid_path, valid, [MIXTURE] + list(args.sources), + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize, **kw_cv) + return train_set, valid_set + + +def _get_musdb_valid(): + # Return musdb valid set. + import yaml + setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml' + setup = yaml.safe_load(open(setup_path, 'r')) + return setup['validation_tracks'] + + +def get_musdb_wav_datasets(args): + """Extract the musdb dataset from the XP arguments.""" + sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8] + metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json") + root = Path(args.musdb) / "train" + if not metadata_file.is_file() and distrib.rank == 0: + metadata_file.parent.mkdir(exist_ok=True, parents=True) + metadata = build_metadata(root, args.sources) + json.dump(metadata, open(metadata_file, "w")) + if distrib.world_size > 1: + distributed.barrier() + metadata = json.load(open(metadata_file)) + + valid_tracks = _get_musdb_valid() + if args.train_valid: + metadata_train = metadata + else: + metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks} + metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks} + if args.full_cv: + kw_cv = {} + else: + kw_cv = {'segment': args.segment, 'shift': args.shift} + train_set = Wavset(root, metadata_train, args.sources, + segment=args.segment, shift=args.shift, + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize) + valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources), + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize, **kw_cv) + return train_set, valid_set diff --git a/AIMeiSheng/demucs/wdemucs.py b/AIMeiSheng/demucs/wdemucs.py new file mode 100644 index 0000000..03d6dd3 --- /dev/null +++ b/AIMeiSheng/demucs/wdemucs.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# For compat +from .hdemucs import HDemucs + +WDemucs = HDemucs diff --git a/AIMeiSheng/docker_demo/Dockerfile b/AIMeiSheng/docker_demo/Dockerfile index 94fb28a..bdb7ab3 100644 --- a/AIMeiSheng/docker_demo/Dockerfile +++ b/AIMeiSheng/docker_demo/Dockerfile @@ -1,29 +1,31 @@ # 系统版本 CUDA Version 11.8.0 # NAME="CentOS Linux" VERSION="7 (Core)" # FROM starmaker.tencentcloudcr.com/starmaker/av/av:1.1 # 基础镜像, python3.9,cuda118,centos7,外加ffmpeg #FROM starmaker.tencentcloudcr.com/starmaker/av/av_base:1.0 FROM registry.ushow.media/av/av_base:1.0 #FROM av_base_test:1.0 RUN source /etc/profile && sed -i 's|mirrorlist=|#mirrorlist=|g' /etc/yum.repos.d/CentOS-Base.repo && sed -i 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-Base.repo && yum clean all && yum install -y unzip && yum install -y libsndfile && yum install -y libsamplerate libsamplerate-devel RUN source /etc/profile && pip3 install librosa==0.9.1 && pip3 install gradio && pip3 install torch==2.1.2 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 RUN source /etc/profile && pip3 install urllib3==1.26.15 && pip3 install coscmd && coscmd config -a AKIDoQmshFWXGitnQmrfCTYNwEExPaU6RVHm -s F9n9E2ZonWy93f04qMaYFfogHadPt62h -b log-sg-1256122840 -r ap-singapore RUN source /etc/profile && pip3 install asteroid-filterbanks RUN source /etc/profile && pip3 install praat-parselmouth==0.4.3 RUN source /etc/profile && pip3 install pyworld RUN source /etc/profile && pip3 install faiss-cpu RUN source /etc/profile && pip3 install torchcrepe RUN source /etc/profile && pip3 install thop RUN source /etc/profile && pip3 install ffmpeg-python RUN source /etc/profile && pip3 install pip3==24.0 RUN source /etc/profile && pip3 install fairseq==0.12.2 RUN source /etc/profile && pip3 install redis==4.5.0 RUN source /etc/profile && pip3 install numpy==1.26.4 +RUN source /etc/profile && pip3 install demucs + COPY ./ /data/code/ WORKDIR /data/code CMD ["/bin/bash", "-c", "source /etc/profile; export PYTHONPATH=/data/code; cd /data/code/AIMeiSheng/docker_demo; python3 offline_server.py"] -#CMD ["/bin/bash", "-c", "source /etc/profile; export PYTHONPATH=/data/code; cd /data/code/AIMeiSheng/docker_demo; python3 tmp.py"] \ No newline at end of file +#CMD ["/bin/bash", "-c", "source /etc/profile; export PYTHONPATH=/data/code; cd /data/code/AIMeiSheng/docker_demo; python3 tmp.py"] diff --git a/AIMeiSheng/docker_demo/common.py b/AIMeiSheng/docker_demo/common.py index 49b3628..2b2a7b3 100644 --- a/AIMeiSheng/docker_demo/common.py +++ b/AIMeiSheng/docker_demo/common.py @@ -1,130 +1,132 @@ import os import sys import time # import logging import urllib, urllib.request # 测试/正式环境 gs_prod = True # if len(sys.argv) > 1 and sys.argv[1] == "prod": # gs_prod = True # print(gs_prod) gs_tmp_dir = "/data/ai_meisheng_tmp" gs_model_dir = "/data/ai_meisheng_models" gs_resource_cache_dir = "/tmp/ai_meisheng_resource_cache" gs_embed_model_path = os.path.join(gs_model_dir, "RawNet3/models/weights/model.pt") gs_svc_model_path = os.path.join(gs_model_dir, "weights/xusong_v2_org_version_alldata_embed_spkenx200x_double_e14_s90706.pth") gs_hubert_model_path = os.path.join(gs_model_dir, "hubert.pt") gs_rmvpe_model_path = os.path.join(gs_model_dir, "rmvpe.pt") gs_embed_model_spk_path = os.path.join(gs_model_dir, "SpeakerEncoder/pretrained_model/best_model.pth.tar") gs_embed_config_spk_path = os.path.join(gs_model_dir, "SpeakerEncoder/pretrained_model/config.json") +gs_demucs_model_path = os.path.join(gs_model_dir, "demucs_model/") + # errcode gs_err_code_success = 0 gs_err_code_download_vocal = 100 gs_err_code_download_svc_url = 101 gs_err_code_svc_process = 102 gs_err_code_transcode = 103 gs_err_code_volume_adjust = 104 gs_err_code_upload = 105 gs_err_code_params = 106 gs_err_code_pending = 107 gs_err_code_target_silence = 108 gs_err_code_too_many_connections = 429 gs_err_code_gender_classify = 430 gs_err_code_vocal_ratio = 431 #人声占比 gs_redis_conf = { "host": "av-credis.starmaker.co", "port": 6379, "pwd": "lKoWEhz%jxTO", } # gs_server_redis_conf = { # "producer": "dev_ai_meisheng_producer", # 输入的队列 # "ai_meisheng_key_prefix": "dev_ai_meisheng_key_", # 存储结果情况 # } gs_server_redis_conf = { "producer": "test_ai_meisheng_producer", # 输入的队列 "ai_meisheng_key_prefix": "test_ai_meisheng_key_", # 存储结果情况 } if gs_prod: gs_server_redis_conf = { "producer": "ai_meisheng_producer", # 输入的队列 "ai_meisheng_key_prefix": "ai_meisheng_key_", # 存储结果情况 } gs_feishu_conf = { "url": "http://sg-prod-songbook-webmp-1:8000/api/feishu/people", "users": [ "18810833785", # 杨建利 "17778007843", # 王健军 "18612496315", # 郭子豪 "18600542290" # 方兵晓 ] } def download2disk(url, dst_path): try: urllib.request.urlretrieve(url, dst_path) return os.path.exists(dst_path) except Exception as ex: print(f"download url={url} error", ex) return False def exec_cmd(cmd): # gs_logger.info(cmd) print(cmd) ret = os.system(cmd) if ret != 0: return False return True def exec_cmd_and_result(cmd): r = os.popen(cmd) text = r.read() r.close() return text def upload_file2cos(key, file_path, region='ap-singapore', bucket_name='av-audit-sync-sg-1256122840'): """ 将文件上传到cos :param key: 桶上的具体地址 :param file_path: 本地文件地址 :param region: 区域 :param bucket_name: 桶地址 :return: """ gs_coscmd = "coscmd" gs_coscmd_conf = "~/.cos.conf" cmd = "{} -c {} -r {} -b {} upload {} {}".format(gs_coscmd, gs_coscmd_conf, region, bucket_name, file_path, key) if exec_cmd(cmd): cmd = "{} -c {} -r {} -b {} info {}".format(gs_coscmd, gs_coscmd_conf, region, bucket_name, key) \ + "| grep Content-Length |awk \'{print $2}\'" res_str = exec_cmd_and_result(cmd) # logging.info("{},res={}".format(key, res_str)) size = float(res_str) if size > 0: return True return False return False def check_input(input_data): key_list = ["record_song_url", "target_url", "start", "end", "vocal_loudness", "female_recording_url", "male_recording_url"] for key in key_list: if key not in input_data.keys(): return False return True diff --git a/AIMeiSheng/docker_demo/svc_online.py b/AIMeiSheng/docker_demo/svc_online.py index 421b5f6..178b32d 100644 --- a/AIMeiSheng/docker_demo/svc_online.py +++ b/AIMeiSheng/docker_demo/svc_online.py @@ -1,207 +1,220 @@ # -*- coding: UTF-8 -*- """ SVC的核心处理逻辑 """ import os import time import socket import shutil import hashlib from AIMeiSheng.meisheng_svc_final import load_model, process_svc_online from AIMeiSheng.cos_similar_ui_zoom import cos_similar -from AIMeiSheng.meisheng_env_preparex import meisheng_env_prepare +from AIMeiSheng.meisheng_env_preparex import meisheng_env_prepare, demucs_env_prepare from AIMeiSheng.voice_classification.online.voice_class_online_fang import VoiceClass, download_volume_balanced from AIMeiSheng.docker_demo.common import * +# from AIMeiSheng.demucs_process_one import demucs_process_one +from AIMeiSheng.separate_demucs import main_seperator import logging hostname = socket.gethostname() log_file_name = f"{os.path.dirname(os.path.abspath(__file__))}/av_meisheng_{hostname}.log" # 设置logger svc_offline_logger = logging.getLogger("svc_offline") file_handler = logging.FileHandler(log_file_name) file_handler.setLevel(logging.INFO) formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S') file_handler.setFormatter(formatter) # if gs_prod: # svc_offline_logger.addHandler(file_handler) if os.path.exists(gs_tmp_dir): shutil.rmtree(gs_tmp_dir) os.makedirs(gs_model_dir, exist_ok=True) os.makedirs(gs_resource_cache_dir, exist_ok=True) # 预设参数 gs_gender_models_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/hub/voice_classification/models.zip" gs_volume_bin_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/dataset/AIMeiSheng/ebur128_tool/v1/ebur128_tool" class GSWorkerAttr: def __init__(self, input_data): # 取出输入资源 vocal_url = input_data["record_song_url"] target_url = input_data["target_url"] start = input_data["start"] # 单位是ms end = input_data["end"] # 单位是ms vocal_loudness = input_data["vocal_loudness"] female_recording_url = input_data["female_recording_url"] male_recording_url = input_data["male_recording_url"] self.distinct_id = hashlib.md5(vocal_url.encode()).hexdigest() self.tmp_dir = os.path.join(gs_tmp_dir, self.distinct_id) if os.path.exists(self.tmp_dir): shutil.rmtree(self.tmp_dir) os.makedirs(self.tmp_dir) self.vocal_url = vocal_url self.target_url = target_url ext = vocal_url.split(".")[-1] self.vocal_path = os.path.join(self.tmp_dir, self.distinct_id + f"_in.{ext}") + self.vocal_demucs_path = os.path.join(self.tmp_dir, self.distinct_id + f"_in_demucs.wav") self.target_wav_path = os.path.join(self.tmp_dir, self.distinct_id + "_out.wav") self.target_wav_ad_path = os.path.join(self.tmp_dir, self.distinct_id + "_out_ad.wav") self.target_path = os.path.join(self.tmp_dir, self.distinct_id + "_out.m4a") self.female_svc_source_url = female_recording_url self.male_svc_source_url = male_recording_url ext = female_recording_url.split(".")[-1] self.female_svc_source_path = os.path.join(self.tmp_dir, self.distinct_id + f"_female.{ext}") ext = male_recording_url.split(".")[-1] self.male_svc_source_path = os.path.join(self.tmp_dir, self.distinct_id + f"_male.{ext}") # self.female_svc_source_path = os.path.join(gs_resource_cache_dir, # hashlib.md5(female_recording_url.encode()).hexdigest() + "." + ext) # ext = male_recording_url.split(".")[-1] # self.male_svc_source_path = os.path.join(gs_resource_cache_dir, # hashlib.md5(male_recording_url.encode()).hexdigest() + "." + ext) self.st_tm = start self.ed_tm = end self.target_loudness = vocal_loudness def log_info_name(self): return f"d_id={self.distinct_id}, vocal_url={self.vocal_url}" def rm_cache(self): if os.path.exists(self.tmp_dir): shutil.rmtree(self.tmp_dir) def init_gender_model(): """ 下载模型 :return: """ dst_model_dir = os.path.join(gs_model_dir, "voice_classification") if not os.path.exists(dst_model_dir): dst_zip_path = os.path.join(gs_model_dir, "models.zip") if not download2disk(gs_gender_models_url, dst_zip_path): svc_offline_logger.fatal(f"download gender_model err={gs_gender_models_url}") cmd = f"cd {gs_model_dir}; unzip {dst_zip_path}; mv models voice_classification; rm -f {dst_zip_path}" os.system(cmd) if not os.path.exists(dst_model_dir): svc_offline_logger.fatal(f"unzip {dst_zip_path} err") music_voice_pure_model = os.path.join(dst_model_dir, "voice_005_rec_v5.pth") music_voice_no_pure_model = os.path.join(dst_model_dir, "voice_10_v5.pth") gender_pure_model = os.path.join(dst_model_dir, "gender_8k_ratev5_v6_adam.pth") gender_no_pure_model = os.path.join(dst_model_dir, "gender_8k_v6_adam.pth") vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model) return vc +def init_demucs_seperator(): + demucs_env_prepare(logging, gs_demucs_model_path) + seperator = main_seperator() + return seperator def init_svc_model(): meisheng_env_prepare(logging, gs_model_dir) embed_model, hubert_model = load_model() cs_sim = cos_similar() return embed_model, hubert_model, cs_sim def download_volume_adjustment(): """ 下载音量调整工具 :return: """ volume_bin_path = os.path.join(gs_model_dir, "ebur128_tool") if not os.path.exists(volume_bin_path): if not download2disk(gs_volume_bin_url, volume_bin_path): svc_offline_logger.fatal(f"download volume_bin err={gs_volume_bin_url}") os.system(f"chmod +x {volume_bin_path}") def volume_adjustment(wav_path, target_loudness, out_path): """ 音量调整 :param wav_path: :param target_loudness: :param out_path: :return: """ volume_bin_path = os.path.join(gs_model_dir, "ebur128_tool") cmd = f"{volume_bin_path} {wav_path} {target_loudness} {out_path}" os.system(cmd) class SVCOnline: def __init__(self): st = time.time() self.gender_model = init_gender_model() self.embed_model, self.hubert_model, self.cs_sim = init_svc_model() + self.demucs_seprator = init_demucs_seperator() download_volume_adjustment() download_volume_balanced() svc_offline_logger.info(f"svc init finished, sp = {time.time() - st}") def gender_process(self, worker_attr): st = time.time() gender, female_rate, is_pure = self.gender_model.process(worker_attr.vocal_path) svc_offline_logger.info( f"{worker_attr.vocal_url}, gender={gender}, female_rate={female_rate}, is_pure={is_pure}, " f"gender_process sp = {time.time() - st}") if gender == 0: gender = 'female' elif gender == 1: gender = 'male' elif female_rate == None: gender = 'male' return gender, gs_err_code_gender_classify elif female_rate > 0.5: gender = 'female' else: gender = 'male' if gender == 'female': if self.gender_model.vocal_ratio < 0.5: print(f"@@@ vocal_ratio: {self.gender_model.vocal_ratio}, gender : {gender}, gs_err_code_vocal_ratio : {gs_err_code_vocal_ratio}") - return gender, gs_err_code_vocal_ratio + # if not demucs_process_one(worker_attr.vocal_path, worker_attr.vocal_demucs_path): + if not self.demucs_seprator.process_one(worker_attr.vocal_path, worker_attr.vocal_demucs_path): + return gender, gs_err_code_vocal_ratio else: if self.gender_model.vocal_ratio < 0.6: print(f"@@@ vocal_ratio: {self.gender_model.vocal_ratio}, gender : {gender}, gs_err_code_vocal_ratio : {gs_err_code_vocal_ratio}") - return gender, gs_err_code_vocal_ratio + # if not demucs_process_one(worker_attr.vocal_path, worker_attr.vocal_demucs_path): + if not self.demucs_seprator.process_one(worker_attr.vocal_path, worker_attr.vocal_demucs_path): + return gender, gs_err_code_vocal_ratio + svc_offline_logger.info(f"{worker_attr.vocal_url}, modified gender={gender}") # err = gs_err_code_success # if female_rate == -1: # err = gs_err_code_target_silence return gender, gs_err_code_success def process(self, worker_attr): gender, err = self.gender_process(worker_attr) if err != gs_err_code_success: return gender, err song_path = worker_attr.female_svc_source_path if gender == "male": song_path = worker_attr.male_svc_source_path params = {'gender': gender, 'tst': worker_attr.st_tm, "tnd": worker_attr.ed_tm, 'delay': 0, 'song_path': None} st = time.time() err_code = process_svc_online(song_path, worker_attr.vocal_path, worker_attr.target_wav_path, self.embed_model, self.hubert_model, self.cs_sim, params) svc_offline_logger.info(f"{worker_attr.vocal_url}, err_code={err_code} process svc sp = {time.time() - st}") return gender, err_code diff --git a/AIMeiSheng/meisheng_env_preparex.py b/AIMeiSheng/meisheng_env_preparex.py index 079ba38..f37fc2d 100644 --- a/AIMeiSheng/meisheng_env_preparex.py +++ b/AIMeiSheng/meisheng_env_preparex.py @@ -1,56 +1,71 @@ import os from AIMeiSheng.docker_demo.common import (gs_svc_model_path, gs_hubert_model_path, gs_embed_model_path,gs_embed_model_spk_path, gs_embed_config_spk_path, gs_rmvpe_model_path, download2disk) def meisheng_env_prepare(logging, AIMeiSheng_Path='./'): cos_path = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/dataset/AIMeiSheng/" rmvpe_model_url = cos_path + "rmvpe.pt" if not os.path.exists(gs_rmvpe_model_path): if not download2disk(rmvpe_model_url, gs_rmvpe_model_path): logging.fatal(f"download rmvpe_model err={rmvpe_model_url}") gs_hubert_model_url = cos_path + "hubert_base.pt" if not os.path.exists(gs_hubert_model_path): if not download2disk(gs_hubert_model_url, gs_hubert_model_path): logging.fatal(f"download hubert_model err={gs_hubert_model_url}") #model_svc = "xusong_v2_org_version_alldata_embed1_enzx_diff_fi_e15_s244110.pth" #model_svc = "xusong_v2_org_version_alldata_embed1_enzx_diff_ocean_ctl_enc_e22_s363704.pth" #model_svc = "xusong_v2_org_version_alldata_embed_spkenx200x_vocal_e22_s95040.pth" model_svc = "xusong_v2_org_version_alldata_embed_spkenx200x_double_e14_s90706.pth" base_dir = os.path.dirname(gs_svc_model_path) os.makedirs(base_dir, exist_ok=True) svc_model_url = cos_path + model_svc if not os.path.exists(gs_svc_model_path): if not download2disk(svc_model_url, gs_svc_model_path): logging.fatal(f"download svc_model err={svc_model_url}") model_embed = "model.pt" base_dir = os.path.dirname(gs_embed_model_path) os.makedirs(base_dir, exist_ok=True) embed_model_url = cos_path + model_embed if not os.path.exists(gs_embed_model_path): if not download2disk(embed_model_url, gs_embed_model_path): logging.fatal(f"download embed_model err={embed_model_url}") model_spk_embed = "best_model.pth.tar" base_dir = os.path.dirname(gs_embed_model_spk_path) os.makedirs(base_dir, exist_ok=True) embed_model_url = cos_path + model_spk_embed if not os.path.exists(gs_embed_model_spk_path): if not download2disk(embed_model_url, gs_embed_model_spk_path): logging.fatal(f"download embed_model err={embed_model_url}") model_spk_embed_cfg = "config.json" base_dir = os.path.dirname(gs_embed_config_spk_path) os.makedirs(base_dir, exist_ok=True) embed_model_url = cos_path + model_spk_embed_cfg if not os.path.exists(gs_embed_config_spk_path): if not download2disk(embed_model_url, gs_embed_config_spk_path): logging.fatal(f"download embed_model err={embed_model_url}") + +def demucs_env_prepare(logging,gs_demucs_model_path): + cos_path = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/dataset/AIMeiSheng/" + + model_demucs_name_list = ["e51eebcc-c1b80bdd.th", "a1d90b5c-ae9d2452.th", "5d2d6c55-db83574e.th", "cfa93e08-61801ae1.th"] + if not os.path.exists(gs_demucs_model_path): + os.makedirs(gs_demucs_model_path, exist_ok=True) + for model_part in model_demucs_name_list: + demucs_model_cos_url = cos_path + model_part + gs_demucs_model_path_tmp = gs_demucs_model_path + model_part + if not download2disk(demucs_model_cos_url, gs_demucs_model_path_tmp): + logging.fatal(f"download embed_model err={demucs_model_cos_url}") + + + if __name__ == "__main__": meisheng_env_prepare() diff --git a/AIMeiSheng/separate_demucs.py b/AIMeiSheng/separate_demucs.py new file mode 100644 index 0000000..99a28a8 --- /dev/null +++ b/AIMeiSheng/separate_demucs.py @@ -0,0 +1,264 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import sys, os +from pathlib import Path + +from dora.log import fatal +import torch as th + +from AIMeiSheng.demucs.api import Separator, save_audio, list_models + +from AIMeiSheng.demucs.apply import BagOfModels +from AIMeiSheng.demucs.htdemucs import HTDemucs +from AIMeiSheng.demucs.pretrained import add_model_flags, ModelLoadingError + + +def get_parser(): + parser = argparse.ArgumentParser("demucs.separate", + description="Separate the sources for the given tracks") + parser.add_argument("tracks", nargs='*', type=Path, default=[], help='Path to tracks') + add_model_flags(parser) + parser.add_argument("--list-models", action="store_true", help="List available models " + "from current repo and exit") + parser.add_argument("-v", "--verbose", action="store_true") + parser.add_argument("-o", + "--out", + type=Path, + default=Path("separated"), + help="Folder where to put extracted tracks. A subfolder " + "with the model name will be created.") + + parser.add_argument("--filename", + default="{track}/{stem}.{ext}", + help="Set the name of output file. \n" + 'Use "{track}", "{trackext}", "{stem}", "{ext}" to use ' + "variables of track name without extension, track extension, " + "stem name and default output file extension. \n" + 'Default is "{track}/{stem}.{ext}".') + parser.add_argument("-d", + "--device", + default=( + "cuda" + if th.cuda.is_available() + else "mps" + if th.backends.mps.is_available() + else "cpu" + ), + help="Device to use, default is cuda if available else cpu") + parser.add_argument("--shifts", + default=1, + type=int, + help="Number of random shifts for equivariant stabilization." + "Increase separation time but improves quality for Demucs. 10 was used " + "in the original paper.") + parser.add_argument("--overlap", + default=0.25, + type=float, + help="Overlap between the splits.") + split_group = parser.add_mutually_exclusive_group() + split_group.add_argument("--no-split", + action="store_false", + dest="split", + default=True, + help="Doesn't split audio in chunks. " + "This can use large amounts of memory.") + split_group.add_argument("--segment", type=int, + help="Set split size of each chunk. " + "This can help save memory of graphic card. ") + parser.add_argument("--two-stems", + dest="stem", metavar="STEM", + help="Only separate audio into {STEM} and no_{STEM}. ") + parser.add_argument("--other-method", dest="other_method", choices=["none", "add", "minus"], + default="add", help='Decide how to get "no_{STEM}". "none" will not save ' + '"no_{STEM}". "add" will add all the other stems. "minus" will use the ' + "original track minus the selected stem.") + depth_group = parser.add_mutually_exclusive_group() + depth_group.add_argument("--int24", action="store_true", + help="Save wav output as 24 bits wav.") + depth_group.add_argument("--float32", action="store_true", + help="Save wav output as float32 (2x bigger).") + parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp", "none"], + help="Strategy for avoiding clipping: rescaling entire signal " + "if necessary (rescale) or hard clipping (clamp).") + format_group = parser.add_mutually_exclusive_group() + format_group.add_argument("--flac", action="store_true", + help="Convert the output wavs to flac.") + format_group.add_argument("--mp3", action="store_true", + help="Convert the output wavs to mp3.") + parser.add_argument("--mp3-bitrate", + default=320, + type=int, + help="Bitrate of converted mp3.") + parser.add_argument("--mp3-preset", choices=range(2, 8), type=int, default=2, + help="Encoder preset of MP3, 2 for highest quality, 7 for " + "fastest speed. Default is 2") + parser.add_argument("-j", "--jobs", + default=0, + type=int, + help="Number of jobs. This can increase memory usage but will " + "be much faster when multiple cores are available.") + + return parser + + +def main_init(opts=None): + parser = get_parser() + args = parser.parse_args(opts) + args.name = 'mdx_extra' + ''' + print("args.list_models:",args.list_models) + if args.list_models: + models = list_models(args.repo) + print("Bag of models:", end="\n ") + print("\n ".join(models["bag"])) + print("Single models:", end="\n ") + print("\n ".join(models["single"])) + sys.exit(0) + #''' + #if len(args.tracks) == 0: + # print("error: the following arguments are required: tracks", file=sys.stderr) + # sys.exit(1) + #print("args:", args) + try: + separator = Separator(model=args.name, + repo=args.repo, + device=args.device, + shifts=args.shifts, + split=args.split, + overlap=args.overlap, + progress=True, + jobs=args.jobs, + segment=args.segment) + except ModelLoadingError as error: + fatal(error.args[0]) + + return separator, args + +def main(separator,args): + #print(args) + max_allowed_segment = float('inf') + if isinstance(separator.model, HTDemucs): + max_allowed_segment = float(separator.model.segment) + elif isinstance(separator.model, BagOfModels): + max_allowed_segment = separator.model.max_allowed_segment + if args.segment is not None and args.segment > max_allowed_segment: + fatal("Cannot use a Transformer model with a longer segment " + f"than it was trained for. Maximum segment is: {max_allowed_segment}") + + if isinstance(separator.model, BagOfModels): + print( + f"Selected model is a bag of {len(separator.model.models)} models. " + "You will see that many progress bars per track." + ) + + if args.stem is not None and args.stem not in separator.model.sources: + fatal( + 'error: stem "{stem}" is not in selected model. ' + "STEM must be one of {sources}.".format( + stem=args.stem, sources=", ".join(separator.model.sources) + ) + ) + out = args.out #/ args.name + out.mkdir(parents=True, exist_ok=True) + print(f"Separated tracks will be stored in {out.resolve()}") + for track in args.tracks: + if not track.exists(): + print(f"File {track} does not exist. If the path contains spaces, " + 'please try again after surrounding the entire path with quotes "".', + file=sys.stderr) + continue + print(f"Separating track {track}") + + origin, res = separator.separate_audio_file(track) + + if args.mp3: + ext = "mp3" + elif args.flac: + ext = "flac" + else: + ext = "wav" + kwargs = { + "samplerate": separator.samplerate, + "bitrate": args.mp3_bitrate, + "preset": args.mp3_preset, + "clip": args.clip_mode, + "as_float": args.float32, + "bits_per_sample": 24 if args.int24 else 16, + } + #print("@@@args.stem:", args.stem) + if args.stem is None: + for name, source in res.items(): + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem=name, + ext=ext, + ) + #print("source:",source.shape) + stem.parent.mkdir(parents=True, exist_ok=True) + #print("@@@@str(stem):", str(stem)) + save_audio(source, str(stem), **kwargs) + else: + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem="minus_" + args.stem, + ext=ext, + ) + if args.other_method == "minus": + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(origin - res[args.stem], str(stem), **kwargs) + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem=args.stem, + ext=ext, + ) + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(res.pop(args.stem), str(stem), **kwargs) + # Warning : after poping the stem, selected stem is no longer in the dict 'res' + if args.other_method == "add": + other_stem = th.zeros_like(next(iter(res.values()))) + for i in res.values(): + other_stem += i + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem="no_" + args.stem, + ext=ext, + ) + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(other_stem, str(stem), **kwargs) + return str(stem) + +class main_seperator(): + def __init__(self, modelname = 'mdx_extra'): + # cmd = f'-o {outpath} --filename {out_name} -n {modelname} {in_name}' + + ''' + @@@out: test_out/mdx_extra + @@@args.name: mdx_extra + track: [PosixPath('test_wav/千年之恋_2-a.wav')] + args.stem: None + ''' + self.separator, self.args = main_init() + + def process_one(self, in_name, out_path): + self.args.tracks = [Path(in_name)] ##输入名字 + self.args.out = Path(os.path.dirname(out_path)) #输出路径 + self.args.filename = os.path.basename(out_path)#[:-4]+'.wav'#out_basename ##输出名字 + outpath = main(self.separator, self.args) + #outpath = os.path.join(out_path, self.args.name ,self.args.filename) + #print("@@@@demucs outpath:", outpath) + return outpath + +if __name__ == "__main__": + in_wav = "test_wav/千年之恋_2-a.wav" + out_wav = "test_out/out.wav" + seperator = main_seperator() + seperator.process_one(in_wav, out_wav)