Page MenuHomePhabricator

No OneTemporary

This file is larger than 256 KB, so syntax highlighting was skipped.
diff --git a/AIMeiSheng/demucs/__init__.py b/AIMeiSheng/demucs/__init__.py
new file mode 100644
index 0000000..e02c0ad
--- /dev/null
+++ b/AIMeiSheng/demucs/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
# Package version string.
__version__ = '4.1.0a2'
diff --git a/AIMeiSheng/demucs/__main__.py b/AIMeiSheng/demucs/__main__.py
new file mode 100644
index 0000000..da0a541
--- /dev/null
+++ b/AIMeiSheng/demucs/__main__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .separate import main
+
if __name__ == "__main__":
    # Entry point when the package is executed with `python -m`: run the
    # separation command line defined in `separate.main`.
    main()
diff --git a/AIMeiSheng/demucs/api.py b/AIMeiSheng/demucs/api.py
new file mode 100644
index 0000000..20079a6
--- /dev/null
+++ b/AIMeiSheng/demucs/api.py
@@ -0,0 +1,392 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""API methods for demucs
+
+Classes
+-------
+`demucs.api.Separator`: The base separator class
+
+Functions
+---------
+`demucs.api.save_audio`: Save an audio
+`demucs.api.list_models`: Get models list
+
+Examples
+--------
+See the end of this module (if __name__ == "__main__")
+"""
+
+import subprocess
+
+import torch as th
+import torchaudio as ta
+
+from dora.log import fatal
+from pathlib import Path
+from typing import Optional, Callable, Dict, Tuple, Union
+
+from .apply import apply_model, _replace_dict
+from .audio import AudioFile, convert_audio, save_audio
+from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT
+from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo
+
+
class LoadAudioError(Exception):
    """Raised when an audio file cannot be decoded by any of the available backends."""
+
+
class LoadModelError(Exception):
    """Raised when the requested pretrained model could not be loaded."""
+
+
class _NotProvided:
    """Sentinel type marking a parameter of `Separator.update_parameter` as "not given".

    A dedicated type is needed because ``None`` is itself a meaningful value for
    several of those parameters (e.g. ``segment``, ``callback``).
    """

    def __repr__(self) -> str:
        # Shown in signatures / help(); the default object repr is just an address.
        return "NotProvided"


# The single sentinel instance used as default value; test with isinstance().
NotProvided = _NotProvided()
+
+
class Separator:
    """Load a pretrained model once and use it to separate audio files or tensors."""

    def __init__(
        self,
        model: str = "htdemucs",
        repo: Optional[Path] = None,
        device: str = "cuda" if th.cuda.is_available() else "cpu",
        shifts: int = 1,
        overlap: float = 0.25,
        split: bool = True,
        segment: Optional[int] = None,
        jobs: int = 0,
        progress: bool = False,
        callback: Optional[Callable[[dict], None]] = None,
        callback_arg: Optional[dict] = None,
    ):
        """
        `class Separator`
        =================

        Parameters
        ----------
        model: Pretrained model name or signature. Default is htdemucs.
        repo: Folder containing all pre-trained models for use.
        segment: Length (in seconds) of each segment (only available if `split` is `True`). If \
            not specified, will use the command line option.
        shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
            apply the opposite shift to the output. This is repeated `shifts` time and all \
            predictions are averaged. This effectively makes the model time equivariant and \
            improves SDR by up to 0.2 points. If not specified, will use the command line option.
        split: If True, the input will be broken down into small chunks (length set by `segment`) \
            and predictions will be performed individually on each and concatenated. Useful for \
            models with a large memory footprint like Tasnet. If not specified, will use the \
            command line option.
        overlap: The overlap between the splits. If not specified, will use the command line \
            option.
        device (torch.device, str, or None): If provided, device on which to execute the \
            computation, otherwise `wav.device` is assumed. When `device` is different from \
            `wav.device`, only local computations will be on `device`, while the entire tracks \
            will be stored on `wav.device`. If not specified, will use the command line option.
        jobs: Number of jobs. This can increase memory usage but will be much faster when \
            multiple cores are available. If not specified, will use the command line option.
        callback: A function will be called when the separation of a chunk starts or finished. \
            The argument passed to the function will be a dict. For more information, please see \
            the Callback section.
        callback_arg: A dict containing private parameters to be passed to callback function. For \
            more information, please see the Callback section.
        progress: If true, show a progress bar.

        Callback
        --------
        The function will be called with only one positional parameter whose type is `dict`. The
        `callback_arg` will be combined with information of current separation progress. The
        progress information will override the values in `callback_arg` if same key has been used.
        To abort the separation, raise `KeyboardInterrupt`.

        Progress information contains several keys (These keys will always exist):
        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
        - `shift_idx`: The index of shifts. Starts from 0.
        - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't
            mean that it is at the 441000 second of the audio, but the "frame" of the tensor.
        - `state`: Could be `"start"` or `"end"`.
        - `audio_length`: Length of the audio (in "frame" of the tensor).
        - `models`: Count of submodels in the model.
        """
        self._name = model
        self._repo = repo
        self._load_model()
        self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split,
                              segment=segment, jobs=jobs, progress=progress, callback=callback,
                              callback_arg=callback_arg)

    def update_parameter(
        self,
        device: Union[str, _NotProvided] = NotProvided,
        shifts: Union[int, _NotProvided] = NotProvided,
        overlap: Union[float, _NotProvided] = NotProvided,
        split: Union[bool, _NotProvided] = NotProvided,
        segment: Optional[Union[int, _NotProvided]] = NotProvided,
        jobs: Union[int, _NotProvided] = NotProvided,
        progress: Union[bool, _NotProvided] = NotProvided,
        callback: Optional[
            Union[Callable[[dict], None], _NotProvided]
        ] = NotProvided,
        callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided,
    ):
        """
        Update the parameters of separation. Parameters left as `NotProvided` keep their
        current value.

        Parameters
        ----------
        segment: Length (in seconds) of each segment (only available if `split` is `True`). If \
            not specified, will use the command line option.
        shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
            apply the opposite shift to the output. This is repeated `shifts` time and all \
            predictions are averaged. This effectively makes the model time equivariant and \
            improves SDR by up to 0.2 points. If not specified, will use the command line option.
        split: If True, the input will be broken down into small chunks (length set by `segment`) \
            and predictions will be performed individually on each and concatenated. Useful for \
            models with a large memory footprint like Tasnet. If not specified, will use the \
            command line option.
        overlap: The overlap between the splits. If not specified, will use the command line \
            option.
        device (torch.device, str, or None): If provided, device on which to execute the \
            computation, otherwise `wav.device` is assumed. When `device` is different from \
            `wav.device`, only local computations will be on `device`, while the entire tracks \
            will be stored on `wav.device`. If not specified, will use the command line option.
        jobs: Number of jobs. This can increase memory usage but will be much faster when \
            multiple cores are available. If not specified, will use the command line option.
        callback: A function will be called when the separation of a chunk starts or finished. \
            The argument passed to the function will be a dict. For more information, please see \
            the Callback section.
        callback_arg: A dict containing private parameters to be passed to callback function. For \
            more information, please see the Callback section.
        progress: If true, show a progress bar.

        Callback
        --------
        The function will be called with only one positional parameter whose type is `dict`. The
        `callback_arg` will be combined with information of current separation progress. The
        progress information will override the values in `callback_arg` if same key has been used.
        To abort the separation, raise `KeyboardInterrupt`.

        Progress information contains several keys (These keys will always exist):
        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
        - `shift_idx`: The index of shifts. Starts from 0.
        - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't
            mean that it is at the 441000 second of the audio, but the "frame" of the tensor.
        - `state`: Could be `"start"` or `"end"`.
        - `audio_length`: Length of the audio (in "frame" of the tensor).
        - `models`: Count of submodels in the model.
        """
        if not isinstance(device, _NotProvided):
            self._device = device
        if not isinstance(shifts, _NotProvided):
            self._shifts = shifts
        if not isinstance(overlap, _NotProvided):
            self._overlap = overlap
        if not isinstance(split, _NotProvided):
            self._split = split
        if not isinstance(segment, _NotProvided):
            self._segment = segment
        if not isinstance(jobs, _NotProvided):
            self._jobs = jobs
        if not isinstance(progress, _NotProvided):
            self._progress = progress
        if not isinstance(callback, _NotProvided):
            self._callback = callback
        if not isinstance(callback_arg, _NotProvided):
            self._callback_arg = callback_arg

    def _load_model(self):
        """Load the model named `self._name` and cache its channel count and samplerate.

        Raises:
            LoadModelError: if `get_model` returns None.
        """
        self._model = get_model(name=self._name, repo=self._repo)
        if self._model is None:
            raise LoadModelError("Failed to load model")
        self._audio_channels = self._model.audio_channels
        self._samplerate = self._model.samplerate

    def _load_audio(self, track: Path):
        """Read `track` at the model's samplerate and channel count.

        FFmpeg is tried first; torchaudio is used as a fallback.

        Raises:
            LoadAudioError: if no backend could read the file. The message lists the
                error reported by every backend that was tried.
        """
        errors = {}
        wav = None

        try:
            wav = AudioFile(track).read(streams=0, samplerate=self._samplerate,
                                        channels=self._audio_channels)
        except FileNotFoundError:
            errors["ffmpeg"] = "FFmpeg is not installed."
        except subprocess.CalledProcessError:
            errors["ffmpeg"] = "FFmpeg could not read the file."

        if wav is None:
            try:
                wav, sr = ta.load(str(track))
            except RuntimeError as err:
                # str(err) instead of err.args[0]: the latter raises IndexError for an
                # exception created without arguments, hiding the real failure.
                errors["torchaudio"] = str(err)
            else:
                wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)

        if wav is None:
            raise LoadAudioError(
                "\n".join(
                    "When trying to load using {}, got the following error: {}".format(
                        backend, error
                    )
                    for backend, error in errors.items()
                )
            )
        return wav

    def separate_tensor(
        self, wav: th.Tensor, sr: Optional[int] = None
    ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
        """
        Separate a loaded tensor.

        Parameters
        ----------
        wav: Waveform of the audio. Should have 2 dimensions, the first is each audio channel, \
            while the second is the waveform of each channel. Type should be float32. \
            e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels.
        sr: Sample rate of the original audio, the wave will be resampled if it doesn't match the \
            model.

        Returns
        -------
        A tuple, whose first element is the original wave and second element is a dict, whose keys
        are the name of stems and values are separated waves. The original wave will have already
        been resampled.

        Notes
        -----
        Use this function with caution. This function does not verify its input data.
        The input tensor is not modified.
        """
        if sr is not None and sr != self.samplerate:
            wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
        # Normalize with the statistics of the channel-averaged signal. This is done
        # out of place: normalizing `wav` in place would alter the caller's tensor,
        # and leave it normalized if the separation is aborted (KeyboardInterrupt)
        # before the values are restored.
        ref = wav.mean(0)
        ref_mean = ref.mean()
        ref_scale = ref.std() + 1e-8
        normalized = (wav - ref_mean) / ref_scale
        out = apply_model(
            self._model,
            normalized[None],
            segment=self._segment,
            shifts=self._shifts,
            split=self._split,
            overlap=self._overlap,
            device=self._device,
            num_workers=self._jobs,
            callback=self._callback,
            callback_arg=_replace_dict(
                self._callback_arg, ("audio_length", wav.shape[1])
            ),
            progress=self._progress,
        )
        if out is None:
            raise KeyboardInterrupt
        # Bring the separated sources back to the scale of the input.
        out *= ref_scale
        out += ref_mean
        return (wav, dict(zip(self._model.sources, out[0])))

    def separate_audio_file(self, file: Path):
        """
        Separate an audio file. The method will automatically read the file.

        Parameters
        ----------
        file: Path of the file to be separated.

        Returns
        -------
        A tuple, whose first element is the original wave and second element is a dict, whose keys
        are the name of stems and values are separated waves. The original wave will have already
        been resampled.
        """
        return self.separate_tensor(self._load_audio(file), self.samplerate)

    @property
    def samplerate(self):
        """Samplerate expected (and produced) by the loaded model."""
        return self._samplerate

    @property
    def audio_channels(self):
        """Number of audio channels expected by the loaded model."""
        return self._audio_channels

    @property
    def model(self):
        """The loaded model object."""
        return self._model
+
+
def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]:
    """
    List the available models. Please remember that not all the returned models can be
    successfully loaded.

    Parameters
    ----------
    repo: Local directory (a `Path` or a plain string path) whose models are to be listed.
        If None, the remote repository (`REMOTE_ROOT`) is listed instead. A path that is
        not an existing directory is reported through dora's `fatal`.

    Returns
    -------
    A dict with two keys ("single" for single models and "bag" for bag of models). Each
    value is what `list_model()` returns for the corresponding repo.
    NOTE(review): the return annotation says dict-of-dicts while the previous docstring
    said lists of str; confirm against `ModelOnlyRepo.list_model`.
    """
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        # Accept string paths too: a bare str has no `.is_dir()`.
        repo = Path(repo)
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    return {"single": model_repo.list_model(), "bag": bag_repo.list_model()}
+
+
if __name__ == "__main__":
    # Manual end-to-end check of this API, driven by the command-line parser of
    # the sibling `separate` module. Two-stem separation is not supported here.
    from .separate import get_parser

    args = get_parser().parse_args()
    separator = Separator(
        model=args.name,
        repo=args.repo,
        device=args.device,
        shifts=args.shifts,
        overlap=args.overlap,
        split=args.split,
        segment=args.segment,
        jobs=args.jobs,
        callback=print,
    )
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)

    # Output format and save options are the same for every track.
    ext = "mp3" if args.mp3 else ("flac" if args.flac else "wav")
    save_kwargs = {
        "samplerate": separator.samplerate,
        "bitrate": args.mp3_bitrate,
        "clip": args.clip_mode,
        "as_float": args.float32,
        "bits_per_sample": 24 if args.int24 else 16,
    }
    for file in args.tracks:
        _, separated = separator.separate_audio_file(file)
        # "song.mp3" -> ["song", "mp3"]; a name without a dot yields [name].
        name_parts = Path(file).name.rsplit(".", 1)
        for stem_name, source in separated.items():
            stem_path = out / args.filename.format(
                track=name_parts[0],
                trackext=name_parts[-1],
                stem=stem_name,
                ext=ext,
            )
            stem_path.parent.mkdir(parents=True, exist_ok=True)
            save_audio(source, str(stem_path), **save_kwargs)
diff --git a/AIMeiSheng/demucs/apply.py b/AIMeiSheng/demucs/apply.py
new file mode 100644
index 0000000..1540f3d
--- /dev/null
+++ b/AIMeiSheng/demucs/apply.py
@@ -0,0 +1,322 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Code to apply a model to a mix. It will handle chunking with overlaps and
+interpolation between chunks, as well as the "shift trick".
+"""
+from concurrent.futures import ThreadPoolExecutor
+import copy
+import random
+from threading import Lock
+import typing as tp
+
+import torch as th
+from torch import nn
+from torch.nn import functional as F
+import tqdm
+
+from .demucs import Demucs
+from .hdemucs import HDemucs
+from .htdemucs import HTDemucs
+from .utils import center_trim, DummyPoolExecutor
+
# Any single (non-bag) separation model that `apply_model` can run directly.
Model = tp.Union[
    Demucs,
    HDemucs,
    HTDemucs,
]
+
+
class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling directly the forward here for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs/HTDemucs models. They must all
                share the same `sources`, `samplerate` and `audio_channels`.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N list (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful if you reuse the models passed).
                An HTDemucs model is only overridden when `segment` does not exceed its
                current `segment`, which acts as its upper bound (see
                `max_allowed_segment`).
        """
        super().__init__()
        assert len(models) > 0
        first = models[0]
        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                # Non-HTDemucs models accept any segment. HTDemucs models must not be
                # given a segment longer than the one they already have, but a
                # shorter one is a valid override.
                if not isinstance(other, HTDemucs) or segment <= other.segment:
                    other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [[1. for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    @property
    def max_allowed_segment(self) -> float:
        """Largest usable segment (seconds): the smallest `segment` among the HTDemucs
        models of the bag, or infinity when the bag contains none."""
        max_allowed_segment = float('inf')
        for model in self.models:
            if isinstance(model, HTDemucs):
                max_allowed_segment = min(max_allowed_segment, float(model.segment))
        return max_allowed_segment

    def forward(self, x):
        raise NotImplementedError("Call `apply_model` on this.")
+
+
class TensorChunk:
    """Lazy view of the window ``[offset, offset + length)`` along a tensor's last dim.

    A chunk built from another chunk is flattened: it points at the underlying
    tensor directly, with the two offsets added together.
    """

    def __init__(self, tensor, offset=0, length=None):
        full = tensor.shape[-1]
        assert offset >= 0
        assert offset < full

        remaining = full - offset
        # The window is clipped so it never runs past the end of `tensor`.
        self.length = remaining if length is None else min(remaining, length)

        if isinstance(tensor, TensorChunk):
            base, base_offset = tensor.tensor, tensor.offset
        else:
            base, base_offset = tensor, 0
        self.tensor = base
        self.offset = base_offset + offset
        self.device = tensor.device

    @property
    def shape(self):
        """Shape of the window as a list (last entry replaced by the window length)."""
        return [*self.tensor.shape[:-1], self.length]

    def padded(self, target_length):
        """Return the window widened symmetrically to ``target_length`` samples.

        Extra samples are taken from the underlying tensor where available and
        zero-padded beyond its boundaries.
        """
        extra = target_length - self.length
        assert extra >= 0

        start = self.offset - extra // 2
        stop = start + target_length
        lo = max(0, start)
        hi = min(self.tensor.shape[-1], stop)

        result = F.pad(self.tensor[..., lo:hi], (lo - start, stop - hi))
        assert result.shape[-1] == target_length
        return result
+
+
def tensor_chunk(tensor_or_chunk):
    """Return `tensor_or_chunk` as a `TensorChunk`.

    Chunks are passed through unchanged; a plain tensor is wrapped in a
    full-length chunk. Anything else fails an assertion.
    """
    if not isinstance(tensor_or_chunk, TensorChunk):
        assert isinstance(tensor_or_chunk, th.Tensor)
        tensor_or_chunk = TensorChunk(tensor_or_chunk)
    return tensor_or_chunk
+
+
def _replace_dict(_dict: tp.Optional[dict], *subs: tp.Tuple[tp.Hashable, tp.Any]) -> dict:
    """Return a shallow copy of ``_dict`` (a new dict when it is None) with ``subs`` applied.

    Each element of ``subs`` is a ``(key, value)`` pair; pairs are applied in order,
    so later ones win. The input dict itself is never modified.
    """
    result = {} if _dict is None else copy.copy(_dict)
    for sub_key, sub_value in subs:
        result[sub_key] = sub_value
    return result
+
+
def apply_model(model: tp.Union[BagOfModels, Model],
                mix: tp.Union[th.Tensor, TensorChunk],
                shifts: int = 1, split: bool = True,
                overlap: float = 0.25, transition_power: float = 1.,
                progress: bool = False, device=None,
                num_workers: int = 0, segment: tp.Optional[float] = None,
                pool=None, lock=None,
                callback: tp.Optional[tp.Callable[[dict], None]] = None,
                callback_arg: tp.Optional[dict] = None) -> th.Tensor:
    """
    Apply model to a given mixture.

    Args:
        model (BagOfModels or Model): model to apply. For a `BagOfModels`, every
            sub-model is applied separately (so each one gets its own random shifts)
            and the outputs are combined with the bag's per-source weights.
        mix (Tensor or TensorChunk): mixture of shape [batch, channels, time].
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into extracts of `segment`
            seconds (the model's own `segment` when not given) and predictions will be
            performed individually on each and overlap-added.
            Useful for model with large memory footprint like Tasnet.
        overlap (float): fraction of each extract shared with the next one when splitting.
        transition_power (float): exponent (must be >= 1) applied to the triangular window
            used to blend overlapping extracts. Larger values give sharper transitions.
        progress (bool): if True, show a progress bar (requires split=True)
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
        num_workers (int): if non zero and device is 'cpu', how many threads to
            use in parallel. A thread pool created for this is shut down before this
            function returns.
        segment (float or None): override the model segment parameter.
        pool: executor used to run the split extracts. Created when None
            (see `num_workers`).
        lock: lock held while `callback` is invoked. Created when None.
        callback: called with a dict when the processing of an extract starts
            (``"state": "start"``) and when it ends (``"state": "end"``). The dict also
            carries ``model_idx_in_bag``, ``shift_idx``, ``segment_offset`` and ``models``.
        callback_arg: extra entries merged into every dict passed to `callback`.

    Returns:
        Tensor of estimated sources of shape [batch, sources, channels, time].
    """
    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if lock is None:
        lock = Lock()
    if pool is None:
        if num_workers > 0 and device.type == 'cpu':
            # Own the thread pool for this call only, then re-enter with it: leaving
            # the `with` block shuts the workers down on return or error instead of
            # leaking a fresh pool's threads on every call.
            with ThreadPoolExecutor(num_workers) as owned_pool:
                return apply_model(model, mix, shifts=shifts, split=split,
                                   overlap=overlap, transition_power=transition_power,
                                   progress=progress, device=device,
                                   num_workers=num_workers, segment=segment,
                                   pool=owned_pool, lock=lock, callback=callback,
                                   callback_arg=callback_arg)
        pool = DummyPoolExecutor()
    callback_arg = _replace_dict(
        callback_arg, *{"model_idx_in_bag": 0, "shift_idx": 0, "segment_offset": 0}.items()
    )
    kwargs: tp.Dict[str, tp.Any] = {
        'shifts': shifts,
        'split': split,
        'overlap': overlap,
        'transition_power': transition_power,
        'progress': progress,
        'device': device,
        'pool': pool,
        'segment': segment,
        'lock': lock,
    }
    out: tp.Union[float, th.Tensor]
    res: tp.Union[float, th.Tensor]
    if isinstance(model, BagOfModels):
        # Special treatment for bag of model.
        # We explicitly apply multiple times `apply_model` so that the random shifts
        # are different for each model.
        estimates: tp.Union[float, th.Tensor] = 0.
        totals = [0.] * len(model.sources)
        callback_arg["models"] = len(model.models)
        for sub_model, model_weights in zip(model.models, model.weights):
            kwargs["callback"] = ((
                lambda d, i=callback_arg["model_idx_in_bag"]: callback(
                    _replace_dict(d, ("model_idx_in_bag", i))) if callback else None)
            )
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)

            res = apply_model(sub_model, mix, **kwargs, callback_arg=callback_arg)
            out = res
            sub_model.to(original_model_device)
            for k, inst_weight in enumerate(model_weights):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out
            callback_arg["model_idx_in_bag"] += 1

        assert isinstance(estimates, th.Tensor)
        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    if "models" not in callback_arg:
        callback_arg["models"] = 1
    model.to(device)
    model.eval()
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape
    if shifts:
        kwargs['shifts'] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        assert isinstance(mix, TensorChunk)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0.
        for shift_idx in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            kwargs["callback"] = (
                (lambda d, i=shift_idx: callback(_replace_dict(d, ("shift_idx", i)))
                 if callback else None)
            )
            res = apply_model(model, shifted, **kwargs, callback_arg=callback_arg)
            shifted_out = res
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        assert isinstance(out, th.Tensor)
        return out
    elif split:
        kwargs['split'] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        if segment is None:
            segment = model.segment
        assert segment is not None and segment > 0.
        segment_length: int = int(model.samplerate * segment)
        stride = int((1 - overlap) * segment_length)
        offsets = range(0, length, stride)
        scale = float(format(stride / model.samplerate, ".2f"))
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
                         th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
        assert len(weight) == segment_length
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment_length)
            future = pool.submit(apply_model, model, chunk, **kwargs, callback_arg=callback_arg,
                                 callback=(lambda d, i=offset:
                                           callback(_replace_dict(d, ("segment_offset", i)))
                                           if callback else None))
            futures.append((future, offset))
        if progress:
            futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
        for future, offset in futures:
            try:
                chunk_out = future.result()  # type: th.Tensor
            except Exception:
                pool.shutdown(wait=True, cancel_futures=True)
                raise
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment_length] += (
                weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        assert isinstance(out, th.Tensor)
        return out
    else:
        valid_length: int
        if isinstance(model, HTDemucs) and segment is not None:
            valid_length = int(segment * model.samplerate)
        elif hasattr(model, 'valid_length'):
            valid_length = model.valid_length(length)  # type: ignore
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        assert isinstance(mix, TensorChunk)
        padded_mix = mix.padded(valid_length).to(device)
        with lock:
            if callback is not None:
                callback(_replace_dict(callback_arg, ("state", "start")))  # type: ignore
        with th.no_grad():
            out = model(padded_mix)
        with lock:
            if callback is not None:
                callback(_replace_dict(callback_arg, ("state", "end")))  # type: ignore
        assert isinstance(out, th.Tensor)
        return center_trim(out, length)
diff --git a/AIMeiSheng/demucs/audio.py b/AIMeiSheng/demucs/audio.py
new file mode 100644
index 0000000..31b29b3
--- /dev/null
+++ b/AIMeiSheng/demucs/audio.py
@@ -0,0 +1,265 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import json
+import subprocess as sp
+from pathlib import Path
+
+import lameenc
+import julius
+import numpy as np
+import torch
+import torchaudio as ta
+import typing as tp
+
+from .utils import temp_filenames
+
+
def _read_info(path):
    """Run ``ffprobe`` on ``path`` and return its JSON output as a dict.

    The output is requested with ``-show_format`` and ``-show_streams``. Errors from
    `subprocess.check_output` (e.g. a non-zero ffprobe exit) propagate to the caller.
    """
    probe_cmd = [
        'ffprobe', "-loglevel", "panic",
        str(path), '-print_format', 'json', '-show_format', '-show_streams',
    ]
    raw = sp.check_output(probe_cmd)
    return json.loads(raw.decode('utf-8'))
+
+
class AudioFile:
    """
    Allows to read audio from any format supported by ffmpeg, as well as resampling or
    converting to mono on the fly. See :method:`read` for more details.
    """
    def __init__(self, path: Path):
        self.path = Path(path)
        self._info = None  # ffprobe output, fetched lazily by `info`

    def __repr__(self):
        features = [("path", self.path)]
        features.append(("samplerate", self.samplerate()))
        features.append(("channels", self.channels()))
        features.append(("streams", len(self)))
        features_str = ", ".join(f"{name}={value}" for name, value in features)
        return f"AudioFile({features_str})"

    @property
    def info(self):
        """ffprobe description of the file, computed on first access and then cached."""
        if self._info is None:
            self._info = _read_info(self.path)
        return self._info

    @property
    def duration(self):
        """Container duration in seconds, as reported by ffprobe."""
        return float(self.info['format']['duration'])

    @property
    def _audio_streams(self):
        """Indices into ``info["streams"]`` of the audio streams, in file order."""
        return [
            index for index, stream in enumerate(self.info["streams"])
            if stream["codec_type"] == "audio"
        ]

    def __len__(self):
        """Number of audio streams in the file."""
        return len(self._audio_streams)

    def channels(self, stream=0):
        """Channel count of the ``stream``-th audio stream."""
        return int(self.info['streams'][self._audio_streams[stream]]['channels'])

    def samplerate(self, stream=0):
        """Sample rate of the ``stream``-th audio stream."""
        return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])

    def read(self,
             seek_time=None,
             duration=None,
             streams=slice(None),
             samplerate=None,
             channels=None):
        """
        Slightly more efficient implementation than stempeg,
        in particular, this will extract all stems at once
        rather than having to loop over one file multiple times
        for each stream.

        Args:
            seek_time (float):  seek time in seconds or None if no seeking is needed.
            duration (float): duration in seconds to extract or None to extract until the end.
            streams (slice, int or list): streams to extract, can be a single int, a list or
                a slice. If it is a slice or list, the output will be of size [S, C, T]
                with S the number of streams, C the number of channels and T the number of samples.
                If it is an int, the output will be [C, T].
            samplerate (int): if provided, will resample on the fly. If None, no resampling will
                be done. Original sampling rate can be obtained with :method:`samplerate`.
            channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
                as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
                See https://sound.stackexchange.com/a/42710.
                Our definition of mono is simply the average of the two channels. Any other
                value will be ignored.
        """
        streams = np.array(range(len(self)))[streams]
        single = not isinstance(streams, np.ndarray)
        if single:
            streams = [streams]

        if duration is None:
            target_size = None
            query_duration = None
        else:
            # NOTE(review): without `samplerate`, stream 0's rate is used here even
            # when other streams are read — confirm streams share a sample rate.
            target_size = int((samplerate or self.samplerate()) * duration)
            query_duration = float((target_size + 1) / (samplerate or self.samplerate()))

        with temp_filenames(len(streams)) as filenames:
            command = ['ffmpeg', '-y']
            command += ['-loglevel', 'panic']
            if seek_time:
                command += ['-ss', str(seek_time)]
            command += ['-i', str(self.path)]
            for stream, filename in zip(streams, filenames):
                command += ['-map', f'0:{self._audio_streams[stream]}']
                if query_duration is not None:
                    command += ['-t', str(query_duration)]
                command += ['-threads', '1']
                command += ['-f', 'f32le']
                if samplerate is not None:
                    command += ['-ar', str(samplerate)]
                command += [filename]

            sp.run(command, check=True)
            wavs = []
            for stream, filename in zip(streams, filenames):
                wav = np.fromfile(filename, dtype=np.float32)
                wav = torch.from_numpy(wav)
                # The raw f32le output is interleaved; de-interleave it with the
                # channel count of the stream decoded into this file (not stream 0's).
                wav = wav.view(-1, self.channels(stream)).t()
                if channels is not None:
                    wav = convert_audio_channels(wav, channels)
                if target_size is not None:
                    wav = wav[..., :target_size]
                wavs.append(wav)
        wav = torch.stack(wavs, dim=0)
        if single:
            wav = wav[0]
        return wav
+
+
def convert_audio_channels(wav, channels=2):
    """Convert audio to the given number of channels.

    Channels are read from the second-to-last dimension of ``wav`` and time from
    the last one; leading dimensions are kept.

    - same channel count: ``wav`` itself is returned;
    - ``channels == 1``: all channels are averaged (downmix);
    - mono input: the single channel is broadcast with ``expand``;
    - more channels than requested: the first ``channels`` are kept;
    - otherwise a ValueError is raised.
    """
    *leading, src_channels, length = wav.shape
    if src_channels == channels:
        return wav
    if channels == 1:
        return wav.mean(dim=-2, keepdim=True)
    if src_channels == 1:
        return wav.expand(*leading, channels, length)
    if src_channels > channels:
        return wav[..., :channels, :]
    # Fewer, but more than one, source channels: no obvious upmix to choose.
    raise ValueError('The audio file has less channels than requested but is not mono.')
+
+
def convert_audio(wav, from_samplerate, to_samplerate, channels) -> torch.Tensor:
    """Convert audio from a given samplerate to a target one and target number of channels.

    The channel conversion (`convert_audio_channels`) is applied first, then the
    result is resampled with `julius.resample_frac`.
    """
    remixed = convert_audio_channels(wav, channels)
    resampled = julius.resample_frac(remixed, from_samplerate, to_samplerate)
    return resampled
+
+
def i16_pcm(wav):
    """Convert audio to 16 bits integer PCM format.

    Floating point input is clipped to [-1, 1], scaled by ``2**15 - 1`` and cast
    to int16; the input tensor itself is left untouched. Non-floating input is
    returned unchanged.
    """
    if wav.dtype.is_floating_point:
        # Out-of-place clamp: `clamp_` would silently clip the caller's tensor.
        return (wav.clamp(-1, 1) * (2**15 - 1)).short()
    return wav
+
+
def f32_pcm(wav):
    """Convert audio to float 32 bits PCM format.

    Floating point input is returned as-is; integer input is cast to float and
    divided by ``2**15 - 1``.
    """
    if not wav.dtype.is_floating_point:
        wav = wav.float() / (2**15 - 1)
    return wav
+
+
def as_dtype_pcm(wav, dtype):
    """Convert audio to either f32 pcm or i16 pcm depending on the given dtype.

    Args:
        wav: audio tensor, floating point or integer PCM.
        dtype: target dtype (presumably a ``torch.dtype``; only its
            ``is_floating_point`` attribute is used). A floating point dtype gives
            f32 PCM via `f32_pcm`, anything else i16 PCM via `i16_pcm`.
    """
    # Dispatch on the requested dtype. Dispatching on `wav.dtype` would make this
    # a no-op: f32_pcm returns float input unchanged and i16_pcm returns
    # non-float input unchanged.
    if dtype.is_floating_point:
        return f32_pcm(wav)
    return i16_pcm(wav)
+
+
def encode_mp3(wav, path, samplerate=44100, bitrate=320, quality=2, verbose=False):
    """Save given audio as mp3. This should work on all OSes.

    ``wav`` must be 2-D ``[channels, time]``. It is converted with `i16_pcm`,
    interleaved, encoded with lameenc and written to ``path``. ``quality`` is
    passed to lame (2 highest, 7 fastest); unless ``verbose`` is set, the
    encoder's ``silence()`` is called.
    """
    num_channels, _ = wav.shape
    pcm = i16_pcm(wav)

    encoder = lameenc.Encoder()
    encoder.set_bit_rate(bitrate)
    encoder.set_in_sample_rate(samplerate)
    encoder.set_channels(num_channels)
    encoder.set_quality(quality)  # 2-highest, 7-fastest
    if not verbose:
        encoder.silence()

    # [channels, time] -> [time, channels] so the raw bytes are sample-interleaved.
    interleaved = pcm.data.cpu().transpose(0, 1).numpy()
    payload = encoder.encode(interleaved.tobytes())
    payload += encoder.flush()
    with open(path, "wb") as f:
        f.write(payload)
+
+
def prevent_clip(wav, mode='rescale'):
    """
    Different strategies for avoiding raw clipping.

    - ``None`` / ``'none'``: return ``wav`` unchanged (any dtype);
    - ``'rescale'``: divide by ``max(1.01 * peak, 1)``, so quiet audio is untouched;
    - ``'clamp'``: clip to [-0.99, 0.99];
    - ``'tanh'``: soft-clip with tanh.

    Other modes raise ValueError. Every mode but the first requires floating
    point input.
    """
    if mode is None or mode == 'none':
        return wav
    assert wav.dtype.is_floating_point, "too late for clipping"
    if mode == 'rescale':
        return wav / max(1.01 * wav.abs().max(), 1)
    if mode == 'clamp':
        return wav.clamp(-0.99, 0.99)
    if mode == 'tanh':
        return torch.tanh(wav)
    raise ValueError(f"Invalid mode {mode}")
+
+
def save_audio(wav: torch.Tensor,
               path: tp.Union[str, Path],
               samplerate: int,
               bitrate: int = 320,
               clip: tp.Literal["rescale", "clamp", "tanh", "none"] = 'rescale',
               bits_per_sample: tp.Literal[16, 24, 32] = 16,
               as_float: bool = False,
               preset: tp.Literal[2, 3, 4, 5, 6, 7] = 2):
    """Save audio file, automatically preventing clipping if necessary
    based on the given `clip` strategy. If the path ends in `.mp3`, this
    will save as mp3 with the given `bitrate`. Use `preset` to set mp3 quality:
    2 for highest quality, 7 for fastest speed.

    `.wav` files are written as signed PCM with `bits_per_sample` bits, or as
    32-bit float PCM when `as_float` is set. `.flac` files use `bits_per_sample`.
    The suffix check is case-insensitive; any other suffix raises ValueError.
    """
    wav = prevent_clip(wav, mode=clip)
    path = Path(path)
    suffix = path.suffix.lower()

    if suffix == ".mp3":
        encode_mp3(wav, path, samplerate, bitrate, preset, verbose=True)
        return
    if suffix == ".flac":
        ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
        return
    if suffix != ".wav":
        raise ValueError(f"Invalid suffix for path: {suffix}")

    encoding = 'PCM_F' if as_float else 'PCM_S'
    if as_float:
        bits_per_sample = 32
    ta.save(str(path), wav, sample_rate=samplerate,
            encoding=encoding, bits_per_sample=bits_per_sample)
diff --git a/AIMeiSheng/demucs/augment.py b/AIMeiSheng/demucs/augment.py
new file mode 100644
index 0000000..6dab7f1
--- /dev/null
+++ b/AIMeiSheng/demucs/augment.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Data augmentations.
+"""
+
+import random
+import torch as th
+from torch import nn
+
+
class Shift(nn.Module):
    """
    Random time-shift augmentation of up to `shift` samples.

    Output length is always `time - shift`. In training mode, each
    (batch item, source) gets a random start offset in [0, shift); with `same=True`
    one offset per batch item is shared by all sources. In eval mode, the leading
    `time - shift` samples are kept.
    """
    def __init__(self, shift=8192, same=False):
        super().__init__()
        self.shift = shift
        self.same = same

    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        length = time - self.shift
        if self.shift <= 0:
            return wav
        if not self.training:
            return wav[..., :length]
        n_offsets = 1 if self.same else sources
        starts = th.randint(self.shift, [batch, n_offsets, 1, 1], device=wav.device)
        starts = starts.expand(-1, sources, channels, -1)
        positions = th.arange(length, device=wav.device)
        # positions + starts broadcasts to (batch, sources, channels, length).
        return wav.gather(3, positions + starts)
+
+
class FlipChannels(nn.Module):
    """
    Randomly swap the left and right channels of stereo signals.

    Only active in training mode and when there are exactly 2 channels; the
    swap decision is drawn independently for each (batch item, source).
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if not self.training or wav.size(2) != 2:
            return wav
        first = th.randint(2, (batch, sources, 1, 1), device=wav.device)
        first = first.expand(-1, -1, -1, time)
        second = 1 - first
        return th.cat([wav.gather(2, first), wav.gather(2, second)], dim=2)
+
+
class FlipSign(nn.Module):
    """
    Random sign flip.

    In training mode, each (batch item, source) signal is multiplied by +1 or -1,
    drawn independently. The output keeps the dtype of the input.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if self.training:
            signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
            # Cast the +/-1 factors to the input dtype: multiplying e.g. a float16 or
            # bfloat16 signal by a float32 tensor would otherwise promote the result
            # to float32. +/-1 is exact in every floating dtype.
            wav = wav * (2 * signs - 1).to(wav.dtype)
        return wav
+
+
class Remix(nn.Module):
    """
    Build new mixtures by shuffling each source across items of the batch.
    """
    def __init__(self, proba=1, group_size=4):
        """
        Shuffle sources within one batch.

        The batch is split into consecutive groups of `group_size` items, and each
        source stream is permuted independently within its group. This keeps the
        probability distribution the same no matter how many GPUs are used. Without
        grouping, more GPUs would make it more likely that two sources from the same
        track stay together, which can hurt performance.

        Args:
            proba: probability of applying the shuffle on a given call (training only).
            group_size: group size; 0 or None means the whole batch is one group.
        """
        super().__init__()
        self.proba = proba
        self.group_size = group_size

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device

        if not (self.training and random.random() < self.proba):
            return wav

        group_size = self.group_size or batch
        if batch % group_size != 0:
            raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
        groups = batch // group_size
        grouped = wav.view(groups, group_size, streams, channels, time)
        # argsort of uniform noise gives an independent random permutation along
        # dim 1 for every (group, stream).
        order = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device), dim=1)
        shuffled = grouped.gather(1, order.expand(-1, -1, -1, channels, time))
        return shuffled.view(batch, streams, channels, time)
+
+
class Scale(nn.Module):
    """
    Random gain augmentation.

    In training mode, with probability `proba` (drawn once per call), every
    (batch item, source) signal is multiplied by its own gain, drawn uniformly
    in [min, max]. The input tensor is never modified in place.
    """
    def __init__(self, proba=1., min=0.25, max=1.25):
        super().__init__()
        self.proba = proba
        self.min = min
        self.max = max

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device
        if self.training and random.random() < self.proba:
            scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max)
            # Out-of-place so the caller's tensor is left untouched (and leaf tensors
            # requiring grad are supported); cast the gains so the output keeps the
            # input dtype, as the former in-place multiply did.
            wav = wav * scales.to(wav.dtype)
        return wav
diff --git a/AIMeiSheng/demucs/demucs.py b/AIMeiSheng/demucs/demucs.py
new file mode 100644
index 0000000..f6a4305
--- /dev/null
+++ b/AIMeiSheng/demucs/demucs.py
@@ -0,0 +1,447 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+
+import julius
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .states import capture_init
+from .utils import center_trim, unfold
+from .transformer import LayerScale
+
+
class BLSTM(nn.Module):
    """
    Bidirectional LSTM whose hidden size equals the input dimension, followed by a
    linear projection from `2 * dim` back to `dim`.

    When `max_steps` is set and the input is longer than it, the sequence is cut into
    overlapping windows of `max_steps` steps (hop `max_steps // 2`). Each window goes
    through the LSTM independently, and the windows are stitched back by dropping
    `max_steps // 4` steps on their inner edges. With `skip=True`, the input is added
    to the output.
    """
    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        batch, chans, steps = x.shape
        residual = x
        chunked = self.max_steps is not None and steps > self.max_steps
        if chunked:
            width = self.max_steps
            hop = width // 2
            windows = unfold(x, width, hop)
            n_windows = windows.shape[2]
            # Fold the windows into the batch dimension.
            x = windows.permute(0, 2, 1, 3).reshape(-1, chans, width)

        # nn.LSTM is not batch_first: it expects (time, batch, features).
        seq = x.permute(2, 0, 1)
        seq = self.linear(self.lstm(seq)[0])
        x = seq.permute(1, 2, 0)

        if chunked:
            windows = x.reshape(batch, -1, chans, width)
            margin = hop // 2
            last = n_windows - 1
            pieces = []
            for k in range(n_windows):
                # Keep the outer edge of the first and last windows; trim `margin`
                # steps on every inner edge.
                start = 0 if k == 0 else margin
                stop = None if (k != 0 and k == last) else -margin
                pieces.append(windows[:, k, :, start:stop])
            x = torch.cat(pieces, -1)[..., :steps]

        return x + residual if self.skip else x
+
+
def rescale_conv(conv, reference):
    """Rescale a convolution's initial weights (and bias) in place.

    Both are divided by `sqrt(std(weight) / reference)`, which pulls the weight
    standard deviation towards `reference`. It is unclear why this helps training,
    but empirically it does.
    """
    factor = (conv.weight.std().detach() / reference) ** 0.5
    conv.weight.data /= factor
    if conv.bias is not None:
        conv.bias.data /= factor
+
+
def rescale_module(module, reference):
    """Apply `rescale_conv` to every 1d/2d (transposed) convolution in `module`,
    including `module` itself."""
    conv_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
    convs = (sub for sub in module.modules() if isinstance(sub, conv_types))
    for conv in convs:
        rescale_conv(conv, reference)
+
+
class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """
    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch
                (hidden size is `int(channels / compress)`).
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention. A negative value uses
                `abs(depth)` layers and disables dilation.
            init: initial scale for the LayerScale at the end of each layer.
            norm: use GroupNorm (with a single group).
            attn: use LocalState attention.
            heads: number of heads for the LocalState attention.
            ndecay: number of decay controls in the LocalState attention.
            lstm: use a BLSTM.
            gelu: Use GELU activation (ReLU otherwise).
            kernel: kernel size for the (dilated) convolutions; must be odd.
            dilate: if true, use dilation `2 ** layer_index`, increasing with the depth.
        """

        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        # Previously `dilate = depth > 0`, which silently ignored `dilate=False`.
        # A negative depth still disables dilation, as before.
        dilate = dilate and depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            dilation = 2 ** d if dilate else 1
            # "Same" padding for an odd kernel with this dilation.
            padding = dilation * (kernel // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden), act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels), nn.GLU(1),
                LayerScale(channels, init),
            ]
            # Both optional modules go right after the activation (index 3); the LSTM
            # is inserted last, so it ends up before the attention.
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        # Each layer is a residual branch added onto its input.
        for layer in self.layers:
            x = x + layer(x)
        return x
+
+
class LocalState(nn.Module):
    """Content-based local attention without positional embeddings.

    Locality comes from a learned, per-head decay penalty that grows with the
    distance between key and query time steps. An experimental frequency-based
    term (noted as a failed experiment) can be enabled with `nfreqs > 0`.
    """
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            decay_proj = nn.Conv1d(channels, heads * ndecay, 1)
            # Near-zero weights and a -2 bias make the sigmoid below output small
            # decays at init, i.e. the widest possible initial attention window.
            decay_proj.weight.data *= 0.01
            assert decay_proj.bias is not None  # for the type checker
            decay_proj.bias.data[:] = -2
            self.query_decay = decay_proj
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        batch, _, steps = x.shape
        n_heads = self.heads
        pos = torch.arange(steps, device=x.device, dtype=x.dtype)
        # delta[t, s] = t - s, where t indexes keys and s indexes queries.
        delta = pos[:, None] - pos[None, :]

        keys = self.key(x).view(batch, n_heads, -1, steps)
        queries = self.query(x).view(batch, n_heads, -1, steps)
        scores = torch.einsum("bhct,bhcs->bhts", keys, queries)
        scores /= keys.shape[2] ** 0.5

        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(batch, n_heads, -1, steps) / self.nfreqs ** 0.5
            scores += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            rates = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = torch.sigmoid(self.query_decay(x).view(batch, n_heads, -1, steps)) / 2
            decay_kernel = -rates.view(-1, 1, 1) * delta.abs() / self.ndecay ** 0.5
            scores += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Strongly discourage each step from attending to itself.
        scores.masked_fill_(torch.eye(steps, device=scores.device, dtype=torch.bool), -100)
        weights = torch.softmax(scores, dim=2)

        content = self.content(x).view(batch, n_heads, -1, steps)
        out = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            out = torch.cat([out, time_sig], 2)
        return x + self.proj(out.reshape(batch, -1, steps))
+
+
class Demucs(nn.Module):
    """Waveform-domain source separation model.

    A stack of strided Conv1d encoder layers, an optional BLSTM bottleneck, and a
    mirrored stack of ConvTranspose1d decoder layers. Each decoder layer receives
    the output of the matching encoder layer as an additive skip connection.
    """
    # capture_init comes from .states; presumably it records the constructor
    # arguments so the model can be re-instantiated later — TODO confirm.
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=64,
                 growth=2.,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 lstm_layers=0,
                 # Convolutions
                 kernel_size=8,
                 stride=4,
                 context=1,
                 # Activations
                 gelu=True,
                 glu=True,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Pre/post processing
                 normalize=True,
                 resample=True,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add a rewrite convolution to each layer (1x1 in the
                encoder, kernel size `2 * context + 1` in the decoder).
            lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
                by default, as this is now replaced by the smaller and faster small LSTMs
                in the DConv branches.
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            context (int): half-width of the kernel of the decoder rewrite convolution
                (kernel size `2 * context + 1`), applied before the transposed
                convolution. If > 0, it provides context from neighboring time steps.
            gelu: use GELU activation function (ReLU otherwise) after the main convs.
            glu (bool): use GLU instead of ReLU after the rewrite convolutions
                (the rewrite convs then output twice as many channels).
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            normalize (bool): normalizes the input audio on the fly, and scales back
                the output by the same amount.
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (float): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment (float): duration of the chunks of audio to ideally evaluate the model on.
                This is used by `demucs.apply.apply_model`.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment = segment
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # Never filled below; holds no parameters.
        self.skip_scales = nn.ModuleList()

        # `activation` is a single module instance shared by every layer (it has no
        # parameters). GLU halves the channel count, hence `ch_scale = 2`.
        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        if gelu:
            act2 = nn.GELU
        else:
            act2 = nn.ReLU

        in_channels = audio_channels
        padding = 0
        # NOTE: the order in which modules are appended to each nn.Sequential defines
        # the state_dict keys (see `load_state_dict`); do not reorder.
        for index in range(depth):
            norm_fn = lambda d: nn.Identity()  # noqa
            if index >= norm_starts:
                norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa

            encode = []
            encode += [
                nn.Conv1d(in_channels, channels, kernel_size, stride),
                norm_fn(channels),
                act2(),
            ]
            attn = index >= dconv_attn
            lstm = index >= dconv_lstm
            if dconv_mode & 1:
                encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            if rewrite:
                encode += [
                    nn.Conv1d(channels, ch_scale * channels, 1),
                    norm_fn(ch_scale * channels), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            # The outermost decoder layer (index 0) outputs all sources at once.
            if index > 0:
                out_channels = in_channels
            else:
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [
                    nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
                    norm_fn(ch_scale * channels), activation]
            if dconv_mode & 2:
                decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            decode += [nn.ConvTranspose1d(channels, out_channels,
                       kernel_size, stride, padding=padding)]
            if index > 0:
                decode += [norm_fn(out_channels), act2()]
            # Decoder layers are stored innermost-first, mirroring the encoder.
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        # Channel count of the innermost encoder output.
        channels = in_channels
        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return a valid input length, at least `length`, to use with the model so that
        no time steps are left over in any convolution, i.e. for all layers,
        (input size - kernel_size) % stride == 0.

        The encoder frame counts are rounded up, so the result is never shorter than
        `length`. Inputs are automatically padded to this length in `forward`, so the
        output keeps the same length as the input.
        """
        if self.resample:
            length *= 2

        # Number of frames after each encoder layer (rounded up, at least 1).
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)

        # Length produced back by the transposed convolutions.
        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        """Separate `mix` into its sources.

        The result is viewed as (batch, len(self.sources), self.audio_channels, time)
        and trimmed back to the input length. The input is presumably
        (batch, audio_channels, time) — TODO confirm.
        """
        x = mix
        length = x.shape[-1]

        if self.normalize:
            # Standardize with the statistics of the mono downmix (mean over dim 1);
            # the output is mapped back with the same mean/std below.
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            x = (x - mean) / (1e-5 + std)
        else:
            mean = 0
            std = 1

        # Pad symmetrically up to a length the conv stack handles exactly.
        delta = self.valid_length(length) - length
        x = F.pad(x, (delta // 2, delta - delta // 2))

        if self.resample:
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)

        if self.lstm:
            x = self.lstm(x)

        for decode in self.decoder:
            # Skip connection from the matching encoder layer (last saved first),
            # presumably center-trimmed to the length of `x`.
            skip = saved.pop(-1)
            skip = center_trim(skip, x)
            x = decode(x + skip)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        x = x * std + mean
        x = center_trim(x, length)
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x

    def load_state_dict(self, state, strict=True):
        """Load `state`, first renaming keys saved by an older model layout.

        Note: `state` is modified in place.
        """
        # fix a mismatch with previous generation Demucs models.
        # Older checkpoints store some encoder/decoder bias/weight under Sequential
        # index 2, where this layout has index 3 (presumably because a norm layer now
        # sits at index 1 — TODO confirm).
        for idx in range(self.depth):
            for a in ['encoder', 'decoder']:
                for b in ['bias', 'weight']:
                    new = f'{a}.{idx}.3.{b}'
                    old = f'{a}.{idx}.2.{b}'
                    if old in state and new not in state:
                        state[new] = state.pop(old)
        super().load_state_dict(state, strict=strict)
diff --git a/AIMeiSheng/demucs/distrib.py b/AIMeiSheng/demucs/distrib.py
new file mode 100644
index 0000000..dc1576c
--- /dev/null
+++ b/AIMeiSheng/demucs/distrib.py
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Distributed training utilities.
+"""
+import logging
+import pickle
+
+import numpy as np
+import torch
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import DataLoader, Subset
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+from dora import distrib as dora_distrib
+
logger = logging.getLogger(__name__)
# Rank of this process and total number of processes. These single-process
# defaults are overwritten by `init()`.
rank = 0
world_size = 1
+
+
def init():
    """Set up distributed training through dora and cache the topology.

    dora's initialization is skipped when torch.distributed is already initialized.
    In both cases, this process's rank and the world size are then stored in the
    module-level `rank` and `world_size` globals.
    """
    global rank, world_size
    if not torch.distributed.is_initialized():
        dora_distrib.init()
    rank, world_size = dora_distrib.rank(), dora_distrib.world_size()
+
+
def average(metrics, count=1.):
    """Average metrics over all workers, weighting each worker's values by `count`.

    Args:
        metrics: a dict mapping names to numbers, or a sequence of numbers.
        count: weight of this worker's values (e.g. the number of examples it saw).

    Returns:
        A dict with the same keys for dict input. For sequence input, a list when
        running distributed, or `metrics` itself when there is a single worker.
        An empty dict yields an empty dict.
    """
    if isinstance(metrics, dict):
        if not metrics:
            # zip(*[]) yields nothing, so the unpacking below would raise ValueError.
            return {}
        keys, values = zip(*sorted(metrics.items()))
        values = average(values, count)
        return dict(zip(keys, values))
    if world_size == 1:
        return metrics
    # The trailing 1, once scaled by `count` and summed, holds the total weight
    # across workers, which normalizes the weighted sums.
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
+
+
def wrap(model):
    """Wrap `model` in DistributedDataParallel on the current CUDA device when there
    are several workers; return it unchanged otherwise."""
    if world_size != 1:
        device = torch.cuda.current_device()
        # find_unused_parameters=True could be passed here if needed.
        return DistributedDataParallel(model, device_ids=[device], output_device=device)
    return model
+
+
def barrier():
    """Wait for all workers to reach this point (no-op with a single worker)."""
    if world_size <= 1:
        return
    torch.distributed.barrier()
+
+
def share(obj=None, src=0):
    """Broadcast a picklable object from worker `src` to every worker.

    With a single worker, `obj` is returned unchanged. Otherwise, worker `src`
    pickles `obj`, broadcasts the payload size and then the bytes as CUDA tensors,
    and every other worker unpickles it. Every worker returns the shared object.
    """
    if world_size == 1:
        return obj

    # Step 1: broadcast the payload size so receivers can allocate a buffer.
    size = torch.empty(1, device='cuda', dtype=torch.long)
    if rank == src:
        payload = pickle.dumps(obj)
        size[0] = len(payload)
    torch.distributed.broadcast(size, src=src)

    # Step 2: broadcast the pickled bytes themselves.
    if rank == src:
        buffer = torch.from_numpy(np.frombuffer(payload, dtype=np.uint8).copy()).cuda()
    else:
        buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
    torch.distributed.broadcast(buffer, src=src)

    if rank != src:
        # Unpickling is acceptable only because the payload comes from our own workers.
        obj = pickle.loads(buffer.cpu().numpy().tobytes())
    logger.debug(f"Shared object of size {len(buffer)}")
    return obj
+
+
def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Build a data loader that handles distributed training.

    With several workers, `shuffle=True` (which you must set whenever gradients are
    computed) uses a DistributedSampler. `shuffle=False` gives every worker a
    disjoint, strided shard of `dataset` instead.
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if not shuffle:
        # Shard manually: DistributedSampler would otherwise replicate some examples.
        shard = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(shard, *args, shuffle=shuffle, **kwargs)

    # `shuffle` is not forwarded: DistributedSampler already shuffles.
    sampler = DistributedSampler(dataset)
    return klass(dataset, *args, **kwargs, sampler=sampler)
diff --git a/AIMeiSheng/demucs/ema.py b/AIMeiSheng/demucs/ema.py
new file mode 100644
index 0000000..101bee0
--- /dev/null
+++ b/AIMeiSheng/demucs/ema.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Inspired from https://github.com/rwightman/pytorch-image-models
+from contextlib import contextmanager
+
+import torch
+
+from .states import swap_state
+
+
class ModelEMA:
    """
    Exponential moving average (EMA) of a model's float32 state.

    Only float32 entries of `model.state_dict()` are tracked. Copies live on `device`,
    or on each tensor's own device when `device` is falsy. The averaged weights can be
    swapped in temporarily:

        ema = ModelEMA(model)
        with ema.swap():
            # compute valid metrics with averaged model.

    With `unbias=True`, the update weight is `1 / count`, where `count` follows
    `count = count * decay + 1`. This compensates for the average starting from the
    initial weights. Otherwise the weight is the constant `1 - decay`.
    """
    def __init__(self, model, decay=0.9999, unbias=True, device='cpu'):
        self.decay = decay
        self.model = model
        self.state = {}
        self.count = 0
        self.device = device
        self.unbias = unbias

        self._init()

    def _init(self):
        """Copy every float32 state entry not tracked yet."""
        for key, val in self.model.state_dict().items():
            if val.dtype == torch.float32 and key not in self.state:
                self.state[key] = val.detach().to(self.device or val.device, copy=True)

    def update(self):
        """Blend the model's current float32 state into the running average."""
        if self.unbias:
            self.count = self.count * self.decay + 1
            ratio = 1 / self.count
        else:
            ratio = 1 - self.decay
        for key, val in self.model.state_dict().items():
            if val.dtype != torch.float32:
                continue
            current = val.detach().to(self.device or val.device)
            self.state[key].mul_(1 - ratio).add_(current, alpha=ratio)

    @contextmanager
    def swap(self):
        """Temporarily load the averaged weights into the model."""
        with swap_state(self.model, self.state):
            yield

    def state_dict(self):
        return {'state': self.state, 'count': self.count}

    def load_state_dict(self, state):
        """Restore from `state_dict()` output, copying into the existing tensors."""
        self.count = state['count']
        for key, value in state['state'].items():
            self.state[key].copy_(value)
diff --git a/AIMeiSheng/demucs/evaluate.py b/AIMeiSheng/demucs/evaluate.py
new file mode 100644
index 0000000..fa2ff45
--- /dev/null
+++ b/AIMeiSheng/demucs/evaluate.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Test time evaluation, either using the original SDR from [Vincent et al. 2006]
+or the newest SDR definition from the MDX 2021 competition (this one will
+be reported as `nsdr` for `new sdr`).
+"""
+
+from concurrent import futures
+import logging
+
+from dora.log import LogProgress
+import numpy as np
+import musdb
+import museval
+import torch as th
+
+from .apply import apply_model
+from .audio import convert_audio, save_audio
+from . import distrib
+from .utils import DummyPoolExecutor
+
+
# Module-level logger, used for progress reporting during evaluation.
logger = logging.getLogger(__name__)
+
+
def new_sdr(references, estimates):
    """
    Compute the SDR as defined by the MDX challenge, in dB.

    Both inputs must be 4-D. Energies are summed over dims 2 and 3, so the result
    has the shape of the first two dims. A small constant is added to the signal
    and error energies to avoid log(0) and division by zero.
    Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license)
    """
    assert references.dim() == 4
    assert estimates.dim() == 4
    eps = 1e-7
    signal = th.sum(th.square(references), dim=(2, 3)) + eps
    error = th.sum(th.square(references - estimates), dim=(2, 3)) + eps
    return 10 * th.log10(signal / error)
+
+
def eval_track(references, estimates, win, hop, compute_sdr=True):
    """Score one track's estimated sources against its references.

    Inputs are presumably (sources, channels, time), as built in `evaluate` — TODO
    confirm. They are transposed to (sources, time, channels) and cast to float64.

    Returns:
        `(scores, new_scores)`. `new_scores` holds the per-source MDX SDR from
        `new_sdr`. `scores` is None when `compute_sdr` is false; otherwise it is the
        first four outputs of `museval.metrics.bss_eval` computed with windows of
        `win` samples and hop `hop`.
    """
    references = references.transpose(1, 2).double()
    estimates = estimates.transpose(1, 2).double()

    new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0]

    if not compute_sdr:
        return None, new_scores
    # museval works on numpy arrays, and Tensor.numpy() only accepts CPU tensors.
    references = references.cpu().numpy()
    estimates = estimates.cpu().numpy()
    scores = museval.metrics.bss_eval(
        references, estimates,
        compute_permutation=False,
        window=win,
        hop=hop,
        framewise_filters=False,
        bsseval_sources_version=False)[:-1]
    return scores, new_scores
+
+
def evaluate(solver, compute_sdr=False):
    """
    Evaluate model using museval.
    compute_sdr=False means using only the MDX definition of the SDR, which
    is much faster to evaluate.

    Each distributed worker processes a strided subset of the musdb test tracks.
    Per-track scores are then shared between all workers and aggregated into a dict:
    for every metric `m` (e.g. "nsdr", and "SDR"/"SIR"/"ISR"/"SAR" when
    `compute_sdr`), it contains `m_<source>` (mean over tracks of the per-track
    median), `m_med_<source>` (median of those medians), and their averages over
    sources as `m` and `m_med` (metric names lower-cased).
    """

    args = solver.args

    output_dir = solver.folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = solver.folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    # we load tracks from the original musdb set
    if args.test.nonhq is None:
        test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True)
    else:
        test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False)
    src_rate = args.dset.musdb_samplerate

    eval_device = 'cpu'

    model = solver.model
    # BSS eval windows and hops of one second of audio.
    win = int(1. * model.samplerate)
    hop = int(1. * model.samplerate)

    # Each worker takes every `world_size`-th track, starting at its own rank.
    indexes = range(distrib.rank, len(test_set), distrib.world_size)
    indexes = LogProgress(logger, indexes, updates=args.misc.num_prints,
                          name='Eval')
    pendings = []

    # Scoring runs in worker processes when args.test.workers is set; otherwise
    # DummyPoolExecutor is used. Note that `pool` is rebound to the executor instance.
    pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor
    with pool(args.test.workers) as pool:
        for index in indexes:
            track = test_set.tracks[index]

            mix = th.from_numpy(track.audio).t().float()
            if mix.dim() == 1:
                mix = mix[None]
            mix = mix.to(solver.device)
            ref = mix.mean(dim=0)  # mono mixture
            # Standardize the mixture with its mono statistics; undone on the estimates.
            mix = (mix - ref.mean()) / ref.std()
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix[None],
                                    shifts=args.test.shifts, split=args.test.split,
                                    overlap=args.test.overlap)[0]
            estimates = estimates * ref.std() + ref.mean()
            estimates = estimates.to(eval_device)

            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            if references.dim() == 2:
                references = references[:, None]
            references = references.to(eval_device)
            references = convert_audio(references, src_rate,
                                       model.samplerate, model.audio_channels)
            if args.test.save:
                folder = solver.folder / "wav" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate)

            pendings.append((track.name, pool.submit(
                eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr)))

        pendings = LogProgress(logger, pendings, updates=args.misc.num_prints,
                               name='Eval (BSS)')
        # tracks[track_name][source] -> {metric_name: list of values}
        tracks = {}
        for track_name, pending in pendings:
            pending = pending.result()
            scores, nsdrs = pending
            tracks[track_name] = {}
            for idx, target in enumerate(model.sources):
                tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]}
            if scores is not None:
                (sdr, isr, sir, sar) = scores
                for idx, target in enumerate(model.sources):
                    values = {
                        "SDR": sdr[idx].tolist(),
                        "SIR": sir[idx].tolist(),
                        "ISR": isr[idx].tolist(),
                        "SAR": sar[idx].tolist()
                    }
                    tracks[track_name][target].update(values)

    # Gather the per-track results of every worker.
    all_tracks = {}
    for src in range(distrib.world_size):
        all_tracks.update(distrib.share(tracks, src))

    result = {}
    # Metric names are taken from the first track's first source (a dict, so
    # iterating it yields the metric names). Raises StopIteration if no track was
    # scored at all.
    metric_names = next(iter(all_tracks.values()))[model.sources[0]]
    for metric_name in metric_names:
        avg = 0
        avg_of_medians = 0
        for source in model.sources:
            medians = [
                np.nanmedian(all_tracks[track][source][metric_name])
                for track in all_tracks.keys()]
            mean = np.mean(medians)
            median = np.median(medians)
            result[metric_name.lower() + "_" + source] = mean
            result[metric_name.lower() + "_med" + "_" + source] = median
            avg += mean / len(model.sources)
            avg_of_medians += median / len(model.sources)
        result[metric_name.lower()] = avg
        result[metric_name.lower() + "_med"] = avg_of_medians
    return result
diff --git a/AIMeiSheng/demucs/grids/__init__.py b/AIMeiSheng/demucs/grids/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/AIMeiSheng/demucs/grids/_explorers.py b/AIMeiSheng/demucs/grids/_explorers.py
new file mode 100644
index 0000000..ec3a858
--- /dev/null
+++ b/AIMeiSheng/demucs/grids/_explorers.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dora import Explorer
+import treetable as tt
+
+
class MyExplorer(Explorer):
    """Dora explorer defining the tracking table shown for the Demucs grids."""
    # Test metrics displayed in the "test" column group.
    test_metrics = ['nsdr', 'sdr_med']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        return [
            tt.group("train", [
                tt.leaf("epoch"),
                tt.leaf("reco", ".3f"),
            ], align=">"),
            tt.group("valid", [
                tt.leaf("penalty", ".1f"),
                tt.leaf("ms", ".1f"),
                tt.leaf("reco", ".2%"),
                tt.leaf("breco", ".2%"),
                tt.leaf("b_nsdr", ".2f"),
                # tt.leaf("b_nsdr_drums", ".2f"),
                # tt.leaf("b_nsdr_bass", ".2f"),
                # tt.leaf("b_nsdr_other", ".2f"),
                # tt.leaf("b_nsdr_vocals", ".2f"),
            ], align=">"),
            tt.group("test", [
                tt.leaf(name, ".2f")
                for name in self.test_metrics
            ], align=">")
        ]

    def process_history(self, history):
        """Summarize an experiment's per-epoch metrics for the tracking table.

        Args:
            history: list of per-epoch metric dicts, each with 'train' and 'valid'
                entries and optionally a 'test' entry. May be empty.

        Returns:
            dict with 'train', 'valid' and 'test' entries. 'train' holds the epoch
            count and the latest train metrics. 'valid' holds the latest valid metrics,
            the best 'reco' so far ('breco'), the best main loss ('bmain'), and the
            'reco_*'/'nsdr*' values of the latest best epoch, prefixed with 'b_'.
            'test' holds the latest test metrics.
        """
        train = {
            'epoch': len(history),
        }
        valid = {}
        test = {}
        best_v_main = float('inf')
        breco = float('inf')
        for metrics in history:
            train.update(metrics['train'])
            valid.update(metrics['valid'])
            if 'main' in metrics['valid']:
                best_v_main = min(best_v_main, metrics['valid']['main']['loss'])
            valid['bmain'] = best_v_main
            valid['breco'] = min(breco, metrics['valid']['reco'])
            breco = valid['breco']
            # An epoch is "best" when its loss (or nsdr) equals the recorded best.
            if (metrics['valid']['loss'] == metrics['valid']['best'] or
                    metrics['valid'].get('nsdr') == metrics['valid']['best']):
                for k, v in metrics['valid'].items():
                    if k.startswith('reco_'):
                        valid['b_' + k[len('reco_'):]] = v
                    if k.startswith('nsdr'):
                        valid[f'b_{k}'] = v
            if 'test' in metrics:
                test.update(metrics['test'])
        # The former unused `metrics = history[-1]` was removed: it raised IndexError
        # for experiments with an empty history.
        return {"train": train, "valid": valid, "test": test}
diff --git a/AIMeiSheng/demucs/grids/mdx.py b/AIMeiSheng/demucs/grids/mdx.py
new file mode 100644
index 0000000..62d447f
--- /dev/null
+++ b/AIMeiSheng/demucs/grids/mdx.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Main training for the Track A MDX models.
+"""
+
+from ._explorers import MyExplorer
+from ..train import main
+
+
# Signatures of the final MDX competition Track A experiments.
TRACK_A = ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68']


@MyExplorer
def explorer(launcher):
    """Schedule the first training round reproducing the MDX Track A models.

    For every Track A signature, the experiment it continued from is launched
    as-is and with `quant.diffq` set to 1e-4 and 3e-4. Once these are trained,
    the `mdx_refine` grid needs to be scheduled.
    """
    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition='learnlab')

    for sig in TRACK_A:
        parent_sig = main.get_xp_from_sig(sig).cfg.continue_from
        parent = main.get_xp_from_sig(parent_sig)
        launcher(parent.argv)
        for diffq in (1e-4, 3e-4):
            launcher(parent.argv, {'quant.diffq': diffq})
diff --git a/AIMeiSheng/demucs/grids/mdx_extra.py b/AIMeiSheng/demucs/grids/mdx_extra.py
new file mode 100644
index 0000000..b99a37b
--- /dev/null
+++ b/AIMeiSheng/demucs/grids/mdx_extra.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Main training for the Track B MDX models.
+"""
+
+from ._explorers import MyExplorer
+from ..train import main
+
# Signatures of the final MDX competition Track B experiments.
TRACK_B = ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08']


@MyExplorer
def explorer(launcher):
    """Schedule training reproducing the MDX Track B models.

    Each Track B signature is followed up its `continue_from` chain to the root
    experiment. That experiment is launched on the `extra44` dataset and on
    `extra_test`, the latter also with `quant.diffq` set to 1e-4 and 3e-4.
    """
    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition='learnlab')

    for sig in TRACK_B:
        # Walk back to the first experiment of the lineage.
        xp = main.get_xp_from_sig(sig)
        while xp.cfg.continue_from is not None:
            xp = main.get_xp_from_sig(xp.cfg.continue_from)

        on_extra44 = launcher.bind(xp.argv, dset='extra44')
        on_extra44()
        on_extra_test = launcher.bind(xp.argv, dset='extra_test')
        on_extra_test()
        for diffq in (1e-4, 3e-4):
            on_extra_test({'quant.diffq': diffq})
diff --git a/AIMeiSheng/demucs/grids/mdx_refine.py b/AIMeiSheng/demucs/grids/mdx_refine.py
new file mode 100644
index 0000000..f62da1d
--- /dev/null
+++ b/AIMeiSheng/demucs/grids/mdx_refine.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Refinement (second-round) training for the Track A MDX models, run after the `mdx` grid.
+"""
+
+from ._explorers import MyExplorer
+from .mdx import TRACK_A
+from ..train import main
+
+
@MyExplorer
def explorer(launcher):
    """Second-round (refinement) grid for the MDX competition Track A.

    Every signature in `TRACK_A` is relaunched with its own arguments. It is
    then relaunched once per DiffQ penalty, continuing from the quantized
    first-round experiment that the `mdx` grid trained for that penalty.

    Raises:
        RuntimeError: if a quantized first-round experiment has not completed
            all of its configured epochs.
    """
    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition='learnlab')

    # WARNING: all the experiments in the `mdx` grid must have completed.
    for sig in TRACK_A:
        xp = main.get_xp_from_sig(sig)
        launcher(xp.argv)
        # The parent experiment does not depend on `diffq`; resolve it once.
        xp_src = main.get_xp_from_sig(xp.cfg.continue_from)
        for diffq in [1e-4, 3e-4]:
            q_argv = [f'quant.diffq={diffq}']
            actual_src = main.get_xp(xp_src.argv + q_argv)
            actual_src.link.load()
            done = len(actual_src.link.history)
            # Explicit check rather than `assert`: asserts are stripped under
            # `python -O`, which would silently continue from unfinished runs.
            if done != actual_src.cfg.epochs:
                raise RuntimeError(
                    f"Source experiment {actual_src.sig} completed {done} of "
                    f"{actual_src.cfg.epochs} epochs; finish the `mdx` grid first.")
            argv = xp.argv + q_argv + [f'continue_from="{actual_src.sig}"']
            launcher(argv)
diff --git a/AIMeiSheng/demucs/grids/mmi.py b/AIMeiSheng/demucs/grids/mmi.py
new file mode 100644
index 0000000..d75aa2b
--- /dev/null
+++ b/AIMeiSheng/demucs/grids/mmi.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import MyExplorer
+from dora import Launcher
+
+
@MyExplorer
def explorer(launcher: Launcher):
    """MMI grid: HDemucs baselines, then HTDemucs sweeps over transformer
    depth and bottom channels, with augmentation, dataset, sparsity and
    segment-duration variants."""
    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="devlab,learnlab,learnfair")  # 3 days

    common = {
        "dset": "extra_mmi_goodclean",
        "test.shifts": 0,
        "model": "htdemucs",
        "htdemucs.dconv_mode": 3,
        "htdemucs.depth": 4,
        "htdemucs.t_dropout": 0.02,
        "htdemucs.t_layers": 5,
        "max_batches": 800,
        "ema.epoch": [0.9, 0.95],
        "ema.batch": [0.9995, 0.9999],
        "dset.segment": 10,
        "batch_size": 32,
    }
    sub = launcher.bind_(common)
    # HDemucs baselines: default dataset, then extra44, then musdb44.
    for dset_override in ({}, {"dset": "extra44"}, {"dset": "musdb44"}):
        sub({"model": "hdemucs", **dset_override})

    sparse = {
        'batch_size': 3 * 8,
        'augment.remix.group_size': 3,
        'htdemucs.t_auto_sparsity': True,
        'htdemucs.t_sparse_self_attn': True,
        'htdemucs.t_sparse_cross_attn': True,
        'htdemucs.t_sparsity': 0.9,
        "htdemucs.t_layers": 7
    }

    with launcher.job_array():
        for transf_layers in [5, 7]:
            for bottom_channels in [0, 512]:
                sub = launcher.bind({
                    "htdemucs.t_layers": transf_layers,
                    "htdemucs.bottom_channels": bottom_channels,
                })
                if transf_layers == 5 and bottom_channels == 0:
                    # Augmentation ablations.
                    sub({"augment.remix.proba": 0.0})
                    sub({
                        "augment.repitch.proba": 0.0,
                        # when doing repitching, we trim the output to align on the
                        # highest change of BPM. When removing repitching,
                        # we simulate it here to ensure the training context is the same.
                        # Another second is lost for all experiments due to the random
                        # shift augmentation.
                        "dset.segment": 10 * 0.88})
                elif transf_layers == 5 and bottom_channels == 512:
                    sub(dset="musdb44")
                    sub(dset="extra44")
                    # Sparse kernel XP, currently not released as kernels are still experimental.
                    sub(sparse, {'dset.segment': 15, "htdemucs.t_layers": 7})

                for duration in [5, 10, 15]:
                    sub({"dset.segment": duration})
diff --git a/AIMeiSheng/demucs/grids/mmi_ft.py b/AIMeiSheng/demucs/grids/mmi_ft.py
new file mode 100644
index 0000000..73e488b
--- /dev/null
+++ b/AIMeiSheng/demucs/grids/mmi_ft.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import MyExplorer
+from dora import Launcher
+from demucs import train
+
+
def get_sub(launcher, sig):
    """Schedule experiment `sig` again and return a launcher bound for fine-tuning it.

    The experiment's original arguments are bound and scheduled once. The
    returned launcher additionally carries `continue_from=sig` and
    `continue_best=True`, so every job it schedules starts from `sig`.
    """
    argv = train.main.get_xp_from_sig(sig).argv
    sub = launcher.bind(argv)
    sub()
    sub.bind_({'continue_from': sig, 'continue_best': True})
    return sub
+
+
@MyExplorer
def explorer(launcher: Launcher):
    """Per-source fine-tuning of the MMI models.

    Each base experiment is scheduled again through `get_sub`. Then, for each
    of its segment durations and each of the 4 sources, one fine-tuning job is
    scheduled with `weights` set to a one-hot vector selecting that source.
    """
    launcher.slurm_(gpus=4, time=3 * 24 * 60, partition="devlab,learnlab,learnfair")  # 3 days
    ft = {
        'optim.lr': 1e-4,
        'augment.remix.proba': 0,
        'augment.scale.proba': 0,
        'augment.shift_same': True,
        'htdemucs.t_weight_decay': 0.05,
        'batch_size': 8,
        'optim.clip_grad': 5,
        'optim.optim': 'adamw',
        'epochs': 50,
        'dset.wav2_valid': True,
        'ema.epoch': [],  # let's make valid a bit faster
    }
    # (base experiment signature, segment durations to fine-tune with)
    plan = [
        ('2899e11a', [15, 18]),
        ('955717e8', [10, 15]),
    ]
    num_sources = 4
    with launcher.job_array():
        for sig, segments in plan:
            sub = get_sub(launcher, sig)
            sub.bind_(ft)
            for segment in segments:
                for source in range(num_sources):
                    one_hot = [int(k == source) for k in range(num_sources)]
                    sub({'weights': one_hot, 'dset.segment': segment})
diff --git a/AIMeiSheng/demucs/grids/repro.py b/AIMeiSheng/demucs/grids/repro.py
new file mode 100644
index 0000000..21d33fc
--- /dev/null
+++ b/AIMeiSheng/demucs/grids/repro.py
@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Easier training for reproducibility
+"""
+
+from ._explorers import MyExplorer
+
+
@MyExplorer
def explorer(launcher):
    """Reproducibility grid for the `demucs` and `hdemucs` models.

    The first config is the v2 Demucs baseline, trained for 360 epochs. The
    two other configs each get a pair of SVD-penalized runs (default seed and
    seed 43). The `newt` config also gets an unpenalized run and an ablation
    study.
    """
    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition='devlab,learnlab')

    for shared in ({'ema.epoch': [0.9, 0.95]},
                   {'ema.batch': [0.9995, 0.9999]},
                   {'epochs': 600}):
        launcher.bind_(shared)

    base = {'model': 'demucs', 'demucs.dconv_mode': 0, 'demucs.gelu': False,
            'demucs.lstm_layers': 2}
    newt = {'model': 'demucs', 'demucs.normalize': True}
    hdem = {'model': 'hdemucs'}
    svd = {'svd.penalty': 1e-5, 'svd': 'base2'}

    with launcher.job_array():
        for model in [base, newt, hdem]:
            sub = launcher.bind(model)
            if model is base:
                # Training the v2 Demucs on MusDB HQ
                sub(epochs=360)
                continue

            # those two will be used in the repro_mdx_a bag of models.
            sub(svd)
            sub(svd, seed=43)
            if model is not newt:
                continue

            # Ablation study
            sub()
            abl = sub.bind(svd)
            for ablation in ({'ema.epoch': [], 'ema.batch': []},
                             {'demucs.dconv_lstm': 10},
                             {'demucs.dconv_attn': 10},
                             {'demucs.dconv_attn': 10, 'demucs.dconv_lstm': 10,
                              'demucs.lstm_layers': 2},
                             {'demucs.dconv_mode': 0},
                             {'demucs.gelu': False}):
                abl(ablation)
diff --git a/AIMeiSheng/demucs/grids/repro_ft.py b/AIMeiSheng/demucs/grids/repro_ft.py
new file mode 100644
index 0000000..7bb4ee8
--- /dev/null
+++ b/AIMeiSheng/demucs/grids/repro_ft.py
@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Fine tuning experiments
+"""
+
+from ._explorers import MyExplorer
+from ..train import main
+
+
@MyExplorer
def explorer(launcher):
    """Fine-tune every finished experiment of the `repro` grid.

    The `grids/repro` folder under dora's directory is scanned for symlinks.
    Each linked experiment whose history already has `cfg.epochs` entries is
    relaunched with `continue_from` pointing at it, plus the fine-tuning
    overrides listed below.
    """
    launcher.slurm_(gpus=8, time=300, partition='devlab,learnlab')

    # Mus
    launcher.slurm_(constraint='volta32gb')

    grid = "repro"
    folder = main.dora.dir / "grids" / grid

    for link in folder.iterdir():
        if not link.is_symlink():
            continue
        # NOTE(review): `link` is a `Path`, not a bare signature string; this
        # relies on `get_xp_from_sig` accepting it. Confirm, or pass `link.name`.
        xp = main.get_xp_from_sig(link)
        xp.link.load()
        if len(xp.link.history) != xp.cfg.epochs:
            continue  # not finished training yet
        sub = launcher.bind(xp.argv, [f'continue_from="{xp.sig}"'])
        # Bound one at a time, in this order (`batch_size` is bound twice:
        # first 32, then 16).
        overrides = [
            {'ema.epoch': [0.9, 0.95], 'ema.batch': [0.9995, 0.9999]},
            {'test.every': 1, 'test.sdr': True, 'epochs': 4},
            {'dset.segment': 28, 'dset.shift': 2},
            {'batch_size': 32},
            {'dset': 'auto_mus', 'augment.remix.proba': 0, 'augment.scale.proba': 0,
             'augment.shift_same': True},
            {'batch_size': 16},
            {'optim.lr': 1e-4},
            {'model_segment': 44},
        ]
        for override in overrides:
            sub.bind_(override)
        sub()
diff --git a/AIMeiSheng/demucs/grids/sdx23.py b/AIMeiSheng/demucs/grids/sdx23.py
new file mode 100644
index 0000000..3bdb419
--- /dev/null
+++ b/AIMeiSheng/demucs/grids/sdx23.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import MyExplorer
+from dora import Launcher
+
+
@MyExplorer
def explorer(launcher: Launcher):
    """SDX23 grid: one job per SDX23 dataset, with `dset.use_musdb` disabled."""
    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="speechgpt,learnfair",
                    mem_per_gpu=None, constraint='')
    launcher.bind_({"dset.use_musdb": False})

    with launcher.job_array():
        for dset in ('sdx23_bleeding', 'sdx23_labelnoise'):
            launcher(dset=dset)
diff --git a/AIMeiSheng/demucs/hdemucs.py b/AIMeiSheng/demucs/hdemucs.py
new file mode 100644
index 0000000..711d471
--- /dev/null
+++ b/AIMeiSheng/demucs/hdemucs.py
@@ -0,0 +1,794 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+This code contains the spectrogram and Hybrid version of Demucs.
+"""
+from copy import deepcopy
+import math
+import typing as tp
+
+from openunmix.filtering import wiener
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .demucs import DConv, rescale_module
+from .states import capture_init
+from .spec import spectro, ispectro
+
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Pad the last dimension of `x` like `F.pad`, tolerating reflect padding on short input.

    `F.pad` in 'reflect' mode requires each padding to be smaller than the
    input length. When that is not the case, `x` is first zero-padded just
    enough for the reflection to be valid. That zero padding goes on the right
    first, then on the left. The rest of the padding is then done by
    reflection.

    Args:
        x: tensor padded along its last dimension.
        paddings: (left, right) number of elements to add.
        mode: padding mode forwarded to `F.pad`.
        value: fill value forwarded to `F.pad`.

    Returns:
        A tensor whose last dimension has size `length + left + right`, with the
        original samples starting at offset `left`.
    """
    original = x
    length = x.shape[-1]
    left, right = paddings
    if mode == 'reflect':
        largest = max(left, right)
        if largest >= length:
            missing = largest - length + 1
            zeros_right = min(right, missing)
            zeros_left = missing - zeros_right
            x = F.pad(x, (zeros_left, zeros_right))
            paddings = (left - zeros_left, right - zeros_right)
    out = F.pad(x, paddings, mode, value)
    assert out.shape[-1] == length + left + right
    assert (out[..., left: left + length] == original).all()
    return out
+
+
class ScaledEmbedding(nn.Module):
    """
    Embedding table with a boosted effective learning rate.

    The underlying weights are stored divided by `scale` and multiplied back by
    `scale` on every access, which boosts how fast the embedding learns.
    With `smooth=True`, the table is initialized as a normalized cumulative sum
    of its random init, so entries vary smoothly with the index.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        table = self.embedding.weight.data
        if smooth:
            # Summing n gaussians grows the std like sqrt(n), so row n is
            # divided by sqrt(n) to keep the initial scale.
            running = torch.cumsum(table, dim=0)
            steps = torch.arange(1, num_embeddings + 1).to(running)
            table[:] = running / steps.sqrt()[:, None]
        table /= scale
        self.scale = scale

    @property
    def weight(self):
        return self.embedding.weight * self.scale

    def forward(self, x):
        return self.embedding(x) * self.scale
+
+
class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
                 rewrite=True):
        """Encoder layer. This is used by both the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: list of kwargs for the DConv class.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()

        def make_norm(channels):
            # GroupNorm when `norm` is set, identity otherwise.
            if norm:
                return nn.GroupNorm(norm_groups, channels)
            return nn.Identity()

        padding = kernel_size // 4 if pad else 0
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = padding
        if freq:
            # Frequency layers convolve along dim -2 only: kernel/stride/pad are (k, 1).
            conv_cls = nn.Conv2d
            self.conv = conv_cls(chin, chout, [kernel_size, 1], [stride, 1], [padding, 0])
        else:
            conv_cls = nn.Conv1d
            self.conv = conv_cls(chin, chout, kernel_size, stride, padding)
        if empty:
            return
        self.norm1 = make_norm(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = conv_cls(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = make_norm(2 * chout)
        self.dconv = DConv(chout, **dconv_kw) if dconv else None

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """
        if not self.freq:
            if x.dim() == 4:
                # Fold the frequency axis into channels for a time-domain layer.
                batch, _, _, frames = x.shape
                x = x.view(batch, -1, frames)
            remainder = x.shape[-1] % self.stride
            if remainder:
                # Right-pad so the time length is a multiple of the stride.
                x = F.pad(x, (0, self.stride - remainder))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                # DConv works on (N, C, T): treat every frequency bin as a batch item.
                batch, chans, freqs, frames = y.shape
                folded = y.permute(0, 2, 1, 3).reshape(-1, chans, frames)
                y = self.dconv(folded).view(batch, freqs, chans, frames).permute(0, 2, 1, 3)
            else:
                y = self.dconv(y)
        if self.rewrite:
            return F.glu(self.norm2(self.rewrite(y)), dim=1)
        return y
+
+
class MultiWrap(nn.Module):
    """
    Take one layer and replicate it N times, with each replica acting on one
    frequency band. Everything is arranged so that, if the N replicas have the
    same weights, the result is exactly equivalent to applying the original
    module on all frequencies.

    This is a bit over-engineered to avoid edge artifacts when splitting
    the frequency bands, but it is possible the naive implementation would work as well...
    """
    def __init__(self, layer, split_ratios):
        """
        Args:
            layer: module to clone, must be either HEncLayer or HDecLayer.
            split_ratios: list of float indicating which ratio to keep for each band.
                `len(split_ratios)` boundaries give `len(split_ratios) + 1` bands.
        """
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        # True when wrapping an encoder layer, False for a decoder layer.
        self.conv = isinstance(layer, HEncLayer)
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq
        for k in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            # Frequency padding is handled manually in `forward`, so disable it
            # on each replica.
            if self.conv:
                lay.conv.padding = (0, 0)
            else:
                lay.pad = False
            # Re-initialize every sub-module that supports it, so the replicas
            # do not keep the wrapped layer's weights.
            for m in lay.modules():
                if hasattr(m, 'reset_parameters'):
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        """Apply each replica to its frequency band (dim 2) and concatenate the results.

        `skip` is required when wrapping a decoder layer; `length` is not used.
        Returns the output tensor for an encoder, and `(output, None)` for a
        decoder, mirroring the wrapped layer's return convention.
        """
        B, C, Fr, T = x.shape

        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = Fr
                    frames = -1
                else:
                    # Snap the band end so it holds a whole number of conv
                    # windows (counting the virtual padding of the first band).
                    limit = int(round(Fr * ratio))
                    le = limit - start
                    if start == 0:
                        le += pad
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= Fr, (limit, Fr)
                y = x[:, :, start:limit, :]
                # Re-create the padding the unsplit conv would have seen: at the
                # start of the frequency axis for the first band, and at the end
                # for the last one.
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                # The next band starts `kernel_size - stride` rows before this
                # one ends, so conv windows tile the axis as in the unsplit conv.
                start = limit - layer.kernel_size + layer.stride
            else:
                if ratio == 1:
                    limit = Fr
                else:
                    limit = int(round(Fr * ratio))
                # Decode with `last=True` so no GELU is applied per band; the
                # GELU is applied once after concatenation below.
                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)
                if outs:
                    # Transposed-conv outputs of neighbouring bands overlap by
                    # `stride` rows: add this band's head into the previous
                    # band's tail, removing the bias so it is only counted once.
                    outs[-1][:, :, -layer.stride:] += (
                        out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
                    out = out[:, :, layer.stride:]
                # Trim what the (disabled) decoder padding would have removed.
                # NOTE(review): `-layer.stride // 2` parses as `(-layer.stride) // 2`,
                # which equals `-(layer.stride // 2)` only for even strides.
                if ratio == 1:
                    out = out[:, :, :-layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2:, :]
                outs.append(out)
                layer.last = last
                start = limit
        out = torch.cat(outs, dim=2)
        # `last` is only bound on the decoder path, hence `not self.conv` first.
        # It holds the flag saved from the final replica (replicas are deep
        # copies, so it matches the wrapped layer's flag).
        if not self.conv and not last:
            out = F.gelu(out)
        if self.conv:
            return out
        else:
            return out, None
+
+
class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
                 context_freq=True, rewrite=True):
        """
        Decoder layer mirroring `HEncLayer` (see it for the shared arguments).

        Extra args:
            last: if True, the final GELU is skipped.
            context_freq: if False, the rewrite conv of a frequency layer has a
                [1, 1 + 2 * context] kernel, i.e. context along time only.

        The transposed convolution and its norm are built even when `empty` is
        True; only the rewrite and DConv parts are skipped in that case.
        """
        super().__init__()

        def make_norm(channels):
            # GroupNorm when `norm` is set, identity otherwise.
            if norm:
                return nn.GroupNorm(norm_groups, channels)
            return nn.Identity()

        self.pad = kernel_size // 4 if pad else 0
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        if freq:
            # Frequency layers act along dim -2 only: kernel/stride are (k, 1).
            conv_cls = nn.Conv2d
            self.conv_tr = nn.ConvTranspose2d(chin, chout, [kernel_size, 1], [stride, 1])
        else:
            conv_cls = nn.Conv1d
            self.conv_tr = nn.ConvTranspose1d(chin, chout, kernel_size, stride)
        self.norm2 = make_norm(chout)
        if empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = conv_cls(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = conv_cls(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                        [0, context])
            self.norm1 = make_norm(2 * chin)
        self.dconv = DConv(chin, **dconv_kw) if dconv else None

    def forward(self, x, skip, length):
        """Decode `x`, first adding the encoder skip connection `skip` (None if `empty`).

        Returns `(z, y)`, where `z` is the layer output and `y` is the tensor
        fed to the transposed convolution. Time layers crop `z` to `length`
        samples; frequency layers drop `self.pad` bins on each side instead.
        """
        if self.freq and x.dim() == 3:
            batch, _, frames = x.shape
            x = x.view(batch, self.chin, -1, frames)

        if self.empty:
            assert skip is None
            y = x
        else:
            x = x + skip
            y = F.glu(self.norm1(self.rewrite(x)), dim=1) if self.rewrite else x
            if self.dconv:
                if self.freq:
                    # DConv works on (N, C, T): treat every frequency bin as a batch item.
                    batch, chans, freqs, frames = y.shape
                    folded = y.permute(0, 2, 1, 3).reshape(-1, chans, frames)
                    y = self.dconv(folded).view(batch, freqs, chans, frames).permute(0, 2, 1, 3)
                else:
                    y = self.dconv(y)

        z = self.norm2(self.conv_tr(y))
        if not self.freq:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        elif self.pad:
            z = z[..., self.pad:-self.pad, :]
        if not self.last:
            z = F.gelu(z)
        return z, y
+
+
class HDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    Open Unmix implementation [Stoter et al. 2019].

    The loss is always on the temporal domain, by backpropagating through the above
    output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
    a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
    contribution, without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=48,
                 channels_time=None,
                 growth=2,
                 # STFT
                 nfft=4096,
                 wiener_iters=0,
                 end_iters=0,
                 wiener_residual=False,
                 cac=True,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 hybrid=True,
                 hybrid_old=False,
                 # Frequency branch
                 multi_freqs=None,
                 multi_freqs_depth=2,
                 freq_emb=0.2,
                 emb_scale=10,
                 emb_smooth=True,
                 # Convolutions
                 kernel_size=8,
                 time_stride=2,
                 stride=4,
                 context=1,
                 context_enc=0,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this require careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
            hybrid_old: some models trained for MDX had a padding bug. This replicates
                this bug to avoid retraining them.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            rescale: weight rescaling trick

        """
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment

        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        if hybrid_old:
            assert hybrid, "hybrid_old must come with hybrid=True"
        if hybrid:
            assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if hybrid:
            self.tencoder = nn.ModuleList()
            self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            # With CaC, real and imaginary parts are separate channels.
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                # Only one frequency bin is left: switch to time convolutions.
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                # Last frequency layer: the kernel spans all remaining bins, unpadded.
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                'kernel_size': ker,
                'stride': stri,
                'freq': freq,
                'pad': pad,
                'norm': norm,
                'rewrite': rewrite,
                'norm_groups': norm_groups,
                'dconv_kw': {
                    'lstm': lstm,
                    'attn': attn,
                    'depth': dconv_depth,
                    'compress': dconv_comp,
                    'init': dconv_init,
                    'gelu': True,
                }
            }
            # Time-branch layers reuse the settings but always use the base
            # kernel/stride with padding. Note: this is a shallow copy, so
            # `dconv_kw` is shared with `kw`.
            kwt = dict(kw)
            kwt['freq'] = 0
            kwt['kernel_size'] = kernel_size
            kwt['stride'] = stride
            kwt['pad'] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec['context_freq'] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            # `dconv_mode & 1` enables DConv in encoders, `& 2` in decoders.
            enc = HEncLayer(chin_z, chout_z,
                            dconv=dconv_mode & 1, context=context_enc, **kw)
            if hybrid and freq:
                # At the last frequency layer the time encoder is `empty` (first
                # conv only); `forward` injects its output into the freq branch.
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
                                 empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                # The outermost decoder outputs all sources stacked on channels.
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
                            last=index == 0, context=context, **kw_dec)
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if hybrid and freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
                                 last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)
            # Decoders are inserted at the front: decoder[0] is the deepest layer.
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

    def _spec(self, x):
        """Compute `spectro(x, nfft, hop_length)` with the last frequency bin dropped.

        In hybrid mode, `x` is padded first, and the result is cropped so that it
        has exactly `ceil(x.shape[-1] / hop_length)` frames.
        """
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        if self.hybrid:
            # We re-pad the signal in order to keep the property
            # that the size of the output is exactly the size of the input
            # divided by the stride (here hop_length), when divisible.
            # This is achieved by padding by 1/4th of the kernel size (here nfft).
            # which is not supported by torch.stft.
            # Having all convolution operations follow this convention allow to easily
            # align the time and frequency branches later on.
            assert hl == nfft // 4
            le = int(math.ceil(x.shape[-1] / hl))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
            else:
                # `hybrid_old` reproduces the historical zero-padding.
                x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))

        z = spectro(x, nfft, hl)[..., :-1, :]
        if self.hybrid:
            # Padding above yields 2 extra frames on each side; crop them.
            assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
            z = z[..., 2:2+le]
        return z

    def _ispec(self, z, length=None, scale=0):
        """Inverse of `_spec`: return a waveform of `length` samples.

        The dropped frequency bin is restored as zeros, and in hybrid mode so
        are the 2 frames cropped on each side. The hop length is divided by
        `4 ** scale`. `length` must be provided in hybrid mode.
        """
        hl = self.hop_length // (4 ** scale)
        z = F.pad(z, (0, 0, 0, 1))
        if self.hybrid:
            z = F.pad(z, (2, 2))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                le = hl * int(math.ceil(length / hl)) + 2 * pad
            else:
                le = hl * int(math.ceil(length / hl))
            x = ispectro(z, hl, length=le)
            if not self.hybrid_old:
                x = x[..., pad:pad + length]
            else:
                x = x[..., :length]
        else:
            x = ispectro(z, hl, length)
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one.
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        # Otherwise: with `niters < 0`, the mixture phase is kept and `m` is used
        # as the magnitude; else Wiener filtering with `niters` iterations.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix.
        # Each batch item is filtered independently, in chunks of
        # `wiener_win_len` frames, and the chunks are concatenated back in time.
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            pos = 0  # unused: rebound by the loop below
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame], mix_stft[sample, frame], niters,
                    residual=residual)
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            # The residual option adds one extra "source"; drop it.
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def forward(self, mix):
        """Separate `mix` into `len(self.sources)` sources.

        Assumes `mix` is (batch, channels, time); the time-branch normalization
        reduces over dims 1 and 2 (TODO confirm against callers). The result is
        the iSTFT of the frequency-branch output (through `_mask`). In hybrid
        mode it is summed with the time-branch output, reshaped to
        (B, S, -1, length).
        """
        x = mix
        length = x.shape[-1]

        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        if self.hybrid:
            # Prepare the time branch input.
            xt = mix
            meant = xt.mean(dim=(1, 2), keepdim=True)
            stdt = xt.std(dim=(1, 2), keepdim=True)
            xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if self.hybrid and idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        if self.hybrid:
            # This value is never read: `xt` is reassigned by the first time
            # decoder, which is `empty` and decodes from `pre` instead.
            xt = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            if self.hybrid:
                # Time decoders only exist for the outermost layers.
                offset = self.depth - len(self.tdecoder)
            if self.hybrid and idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        # to cpu as mps doesnt support complex numbers
        # demucs issue #435 ##432
        # NOTE: in this case z already is on cpu
        # TODO: remove this when mps supports complex numbers
        x_is_mps = x.device.type == "mps"
        if x_is_mps:
            x = x.cpu()

        zout = self._mask(z, x)
        x = self._ispec(zout, length)

        # back to mps device
        if x_is_mps:
            x = x.to('mps')

        if self.hybrid:
            xt = xt.view(B, S, -1, length)
            xt = xt * stdt[:, None] + meant[:, None]
            x = xt + x
        return x
diff --git a/AIMeiSheng/demucs/htdemucs.py b/AIMeiSheng/demucs/htdemucs.py
new file mode 100644
index 0000000..5d2eaaa
--- /dev/null
+++ b/AIMeiSheng/demucs/htdemucs.py
@@ -0,0 +1,660 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# First author is Simon Rouard.
+"""
+This code contains the spectrogram and Hybrid version of Demucs.
+"""
+import math
+
+from openunmix.filtering import wiener
+import torch
+from torch import nn
+from torch.nn import functional as F
+from fractions import Fraction
+from einops import rearrange
+
+from .transformer import CrossTransformerEncoder
+
+from .demucs import rescale_module
+from .states import capture_init
+from .spec import spectro, ispectro
+from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
+
+
class HTDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    Open Unmix implementation [Stoter et al. 2019].

    The loss is always on the temporal domain, by backpropagating through the above
    output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
    a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
    contribution, without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """

    @capture_init
    def __init__(
        self,
        sources,
        # Channels
        audio_channels=2,
        channels=48,
        channels_time=None,
        growth=2,
        # STFT
        nfft=4096,
        wiener_iters=0,
        end_iters=0,
        wiener_residual=False,
        cac=True,
        # Main structure
        depth=4,
        rewrite=True,
        # Frequency branch
        multi_freqs=None,
        multi_freqs_depth=3,
        freq_emb=0.2,
        emb_scale=10,
        emb_smooth=True,
        # Convolutions
        kernel_size=8,
        time_stride=2,
        stride=4,
        context=1,
        context_enc=0,
        # Normalization
        norm_starts=4,
        norm_groups=4,
        # DConv residual branch
        dconv_mode=1,
        dconv_depth=2,
        dconv_comp=8,
        dconv_init=1e-3,
        # Before the Transformer
        bottom_channels=0,
        # Transformer
        t_layers=5,
        t_emb="sin",
        t_hidden_scale=4.0,
        t_heads=8,
        t_dropout=0.0,
        t_max_positions=10000,
        t_norm_in=True,
        t_norm_in_group=False,
        t_group_norm=False,
        t_norm_first=True,
        t_norm_out=True,
        t_max_period=10000.0,
        t_weight_decay=0.0,
        t_lr=None,
        t_layer_scale=True,
        t_gelu=True,
        t_weight_pos_embed=1.0,
        t_sin_random_shift=0,
        t_cape_mean_normalize=True,
        t_cape_augment=True,
        t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
        t_sparse_self_attn=False,
        t_sparse_cross_attn=False,
        t_mask_type="diag",
        t_mask_random_seed=42,
        t_sparse_attn_window=500,
        t_global_window=100,
        t_sparsity=0.95,
        t_auto_sparsity=False,
        # ------ Particular parameters
        t_cross_first=False,
        # Weight init
        rescale=0.1,
        # Metadata
        samplerate=44100,
        segment=10,
        use_train_segment=True,
    ):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this require careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_init: initial scale for the DConv branch LayerScale.
            bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
                transformer in order to change the number of channels
            t_layers: number of layers in each branch (waveform and spec) of the transformer
            t_emb: "sin", "cape" or "scaled"
            t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
                for instance if C = 384 (the number of channels in the transformer) and
                t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
                384 * 4 = 1536
            t_heads: number of heads for the transformer
            t_dropout: dropout in the transformer
            t_max_positions: max_positions for the "scaled" positional embedding, only
                useful if t_emb="scaled"
            t_norm_in: (bool) norm before adding positional embedding and getting into the
                transformer layers
            t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
                timesteps (GroupNorm with group=1)
            t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
                timesteps (GroupNorm with group=1)
            t_norm_first: (bool) if True the norm is before the attention and before the FFN
            t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
            t_max_period: (float) denominator in the sinusoidal embedding expression
            t_weight_decay: (float) weight decay for the transformer
            t_lr: (float) specific learning rate for the transformer
            t_layer_scale: (bool) Layer Scale for the transformer
            t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
            t_weight_pos_embed: (float) weighting of the positional embedding
            t_sin_random_shift: forwarded to `CrossTransformerEncoder` as `sin_random_shift`.
            t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
                see: https://arxiv.org/abs/2106.03143
            t_cape_augment: (bool) if t_emb="cape", must be True during training and False
                during the inference, see: https://arxiv.org/abs/2106.03143
            t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
                see: https://arxiv.org/abs/2106.03143
            t_sparse_self_attn: (bool) if True, the self attentions are sparse
            t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
                unless you designed really specific masks)
            t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
                with '_' between: i.e. "diag_jmask_random" (note that this is permutation
                invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
            t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
                that generated the random part of the mask
            t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
                a key (j), the mask is True if |i-j|<=t_sparse_attn_window
            t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
                and mask[:, :t_global_window] will be True
            t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
                level of the random part of the mask.
            t_auto_sparsity: forwarded to `CrossTransformerEncoder` as `auto_sparsity`.
            t_cross_first: (bool) if True cross attention is the first layer of the
                transformer (False seems to be better)
            rescale: weight rescaling trick
            samplerate: sample rate of the audio, in Hz.
            segment: duration in seconds of the training segments; `segment * samplerate`
                is the training length in samples.
            use_train_segment: (bool) if True, the actual size that is used during the
                training is used during inference.
        """
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.bottom_channels = bottom_channels
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.use_train_segment = use_train_segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.tencoder = nn.ModuleList()
        self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                    "gelu": True,
                },
            }
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(
                chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
            )
            if freq:
                tenc = HEncLayer(
                    chin,
                    chout,
                    dconv=dconv_mode & 1,
                    context=context_enc,
                    empty=last_freq,
                    **kwt
                )
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(
                chout_z,
                chin_z,
                dconv=dconv_mode & 2,
                last=index == 0,
                context=context,
                **kw_dec
            )
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if freq:
                tdec = HDecLayer(
                    chout,
                    chin,
                    dconv=dconv_mode & 2,
                    empty=last_freq,
                    last=index == 0,
                    context=context,
                    **kwt
                )
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale
                )
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

        transformer_channels = channels * growth ** (depth - 1)
        if bottom_channels:
            self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )
            self.channel_upsampler_t = nn.Conv1d(
                transformer_channels, bottom_channels, 1
            )
            self.channel_downsampler_t = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )

            transformer_channels = bottom_channels

        if t_layers > 0:
            self.crosstransformer = CrossTransformerEncoder(
                dim=transformer_channels,
                emb=t_emb,
                hidden_scale=t_hidden_scale,
                num_heads=t_heads,
                num_layers=t_layers,
                cross_first=t_cross_first,
                dropout=t_dropout,
                max_positions=t_max_positions,
                norm_in=t_norm_in,
                norm_in_group=t_norm_in_group,
                group_norm=t_group_norm,
                norm_first=t_norm_first,
                norm_out=t_norm_out,
                max_period=t_max_period,
                weight_decay=t_weight_decay,
                lr=t_lr,
                layer_scale=t_layer_scale,
                gelu=t_gelu,
                sin_random_shift=t_sin_random_shift,
                weight_pos_embed=t_weight_pos_embed,
                cape_mean_normalize=t_cape_mean_normalize,
                cape_augment=t_cape_augment,
                cape_glob_loc_scale=t_cape_glob_loc_scale,
                sparse_self_attn=t_sparse_self_attn,
                sparse_cross_attn=t_sparse_cross_attn,
                mask_type=t_mask_type,
                mask_random_seed=t_mask_random_seed,
                sparse_attn_window=t_sparse_attn_window,
                global_window=t_global_window,
                sparsity=t_sparsity,
                auto_sparsity=t_auto_sparsity,
            )
        else:
            self.crosstransformer = None

    def _spec(self, x):
        """Compute the STFT of waveform `x` (via `spectro`) with aligned framing.

        Returns exactly `ceil(x.shape[-1] / hop_length)` frames; the last frequency
        bin produced by `spectro` is dropped.
        """
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft).
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allow to easily
        # align the time and frequency branches later on.
        assert hl == nfft // 4
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")

        z = spectro(x, nfft, hl)[..., :-1, :]
        assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
        z = z[..., 2: 2 + le]
        return z

    def _ispec(self, z, length=None, scale=0):
        """Inverse of `_spec`: restore the dropped frequency bin and the 2 frames
        trimmed on each side, run `ispectro`, and crop to `length` samples."""
        hl = self.hop_length // (4**scale)
        z = F.pad(z, (0, 0, 0, 1))
        z = F.pad(z, (2, 2))
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = ispectro(z, hl, length=le)
        x = x[..., pad: pad + length]
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one.
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix.
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            out = []
            # Filter the time frames in windows of `wiener_win_len` to bound memory use.
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame],
                    mix_stft[sample, frame],
                    niters,
                    residual=residual,
                )
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def valid_length(self, length: int):
        """
        Return a length that is appropriate for evaluation.
        In our case, always return the training length, unless
        it is smaller than the given length, in which case this
        raises an error.
        """
        if not self.use_train_segment:
            return length
        training_length = int(self.segment * self.samplerate)
        if training_length < length:
            raise ValueError(
                    f"Given length {length} is longer than "
                    f"training length {training_length}")
        return training_length

    def forward(self, mix):
        """Separate `mix` into `len(self.sources)` sources.

        `mix` is read with its time axis last; the time branch normalizes over
        dims (1, 2), so it is presumably (batch, channels, time) — TODO confirm.
        In eval mode with `use_train_segment`, inputs shorter than the training
        segment are zero-padded to it, and the output is cropped back.

        Returns:
            tensor viewed as (B, S, -1, time) where S = len(self.sources).
        """
        length_pre_pad = None
        if self.use_train_segment:
            if self.training:
                self.segment = Fraction(mix.shape[-1], self.samplerate)
            else:
                training_length = int(self.segment * self.samplerate)
                if mix.shape[-1] < training_length:
                    length_pre_pad = mix.shape[-1]
                    mix = F.pad(mix, (0, training_length - length_pre_pad))
        # Length of the signal actually processed (after any padding above); both
        # branches are reconstructed to exactly this many samples. Deriving it
        # from `mix` keeps the two branches consistent even when an eval input is
        # longer than the training segment.
        length = mix.shape[-1]
        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)
        if self.crosstransformer is not None:
            if self.bottom_channels:
                b, c, f, t = x.shape
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_upsampler_t(xt)

            x, xt = self.crosstransformer(x, xt)

            if self.bottom_channels:
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_downsampler_t(xt)

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            offset = self.depth - len(self.tdecoder)
            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        # to cpu as mps doesnt support complex numbers
        # demucs issue #435 ##432
        # NOTE: in this case z already is on cpu
        # TODO: remove this when mps supports complex numbers
        x_is_mps = x.device.type == "mps"
        if x_is_mps:
            x = x.cpu()

        zout = self._mask(z, x)
        x = self._ispec(zout, length)

        # back to mps device
        if x_is_mps:
            x = x.to("mps")

        xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
        if length_pre_pad is not None:
            # Undo the eval-time padding to the training segment.
            x = x[..., :length_pre_pad]
        return x
diff --git a/AIMeiSheng/demucs/pretrained.py b/AIMeiSheng/demucs/pretrained.py
new file mode 100644
index 0000000..80ae49c
--- /dev/null
+++ b/AIMeiSheng/demucs/pretrained.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Loading pretrained models.
+"""
+
+import logging
+from pathlib import Path
+import typing as tp
+
+from dora.log import fatal, bold
+
+from .hdemucs import HDemucs
+from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
+from .states import _check_diffq
+
logger = logging.getLogger(__name__)
# Base URL; `_parse_remote_files` prepends it (plus the current `root:` prefix)
# to every checkpoint listed in the remote index file.
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/"
# Directory next to this module holding `files.txt` and the bag-of-models YAML files.
REMOTE_ROOT = Path(__file__).parent / 'remote'

# Default source names (used here to build the `demucs_unittest` toy model).
SOURCES = ["drums", "bass", "other", "vocals"]
# Model name `get_model_from_args` falls back to when no name was given.
DEFAULT_MODEL = 'htdemucs'
+
+
def demucs_unittest():
    """Return a tiny, untrained `HDemucs` (4 hidden channels) over `SOURCES`.

    `get_model` returns this when asked for the name ``'demucs_unittest'``.
    """
    return HDemucs(sources=SOURCES, channels=4)
+
+
def add_model_flags(parser):
    """Register the model-selection options on an argparse `parser`.

    `-s/--sig` and `-n/--name` are mutually exclusive; `--repo` (parsed as a
    `Path`) can be combined with either.
    """
    selection = parser.add_mutually_exclusive_group(required=False)
    selection.add_argument("-s", "--sig", help="Locally trained XP signature.")
    selection.add_argument(
        "-n",
        "--name",
        default="htdemucs",
        help="Pretrained model name or signature. Default is htdemucs.",
    )
    parser.add_argument(
        "--repo",
        type=Path,
        help="Folder containing all pre-trained models for use with -n.",
    )
+
+
+def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
+ root: str = ''
+ models: tp.Dict[str, str] = {}
+ for line in remote_file_list.read_text().split('\n'):
+ line = line.strip()
+ if line.startswith('#'):
+ continue
+ elif len(line) == 0:
+ continue
+ elif line.startswith('root:'):
+ root = line.split(':', 1)[1].strip()
+ else:
+ sig = line.split('-', 1)[0]
+ assert sig not in models
+ models[sig] = ROOT_URL + root + line
+ return models
+
+
def get_model(name: str,
              repo: tp.Optional[Path] = None):
    """`name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo or the specified local repo if `repo` is not None.

    Args:
        name: model/bag name or signature; ``'demucs_unittest'`` returns a tiny test model.
        repo: optional local directory to load models from instead of the remote index.

    Returns:
        the loaded model, switched to eval mode (except for ``'demucs_unittest'``,
        which is returned as built).

    Raises:
        ImportError: re-raised from loading; when it mentions ``diffq``,
            `_check_diffq` is called first.
    """
    if name == 'demucs_unittest':
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    try:
        model = any_repo.get_model(name)
    except ImportError as exc:
        # `str(exc)` works even when the exception carries no args; `exc.args[0]`
        # would raise IndexError there and hide the original ImportError.
        if 'diffq' in str(exc):
            _check_diffq()
        raise

    model.eval()
    return model
+
+
def get_model_from_args(args):
    """
    Load the model selected by the command line options of `add_model_flags`.

    If `args.name` is None it is set to `DEFAULT_MODEL` (mutating `args`) and a
    notice about the default model is printed before loading.
    """
    name = args.name
    if name is None:
        name = args.name = DEFAULT_MODEL
        print(bold("Important: the default model was recently changed to `htdemucs`"),
              "the latest Hybrid Transformer Demucs model. In some cases, this model can "
              "actually perform worse than previous models. To get back the old default model "
              "use `-n mdx_extra_q`.")
    return get_model(name=name, repo=args.repo)
diff --git a/AIMeiSheng/demucs/py.typed b/AIMeiSheng/demucs/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/AIMeiSheng/demucs/remote/files.txt b/AIMeiSheng/demucs/remote/files.txt
new file mode 100644
index 0000000..346eb33
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/files.txt
@@ -0,0 +1,32 @@
+# MDX Models
+root: mdx_final/
+0d19c1c6-0f06f20e.th
+5d2d6c55-db83574e.th
+7d865c68-3d5dd56b.th
+7ecf8ec1-70f50cc9.th
+a1d90b5c-ae9d2452.th
+c511e2ab-fe698775.th
+cfa93e08-61801ae1.th
+e51eebcc-c1b80bdd.th
+6b9c2ca1-3fd82607.th
+b72baf4e-8778635e.th
+42e558d4-196e0e1b.th
+305bc58f-18378783.th
+14fc6a69-a89dd0ee.th
+464b36d7-e5a9386e.th
+7fd6ef75-a905dd85.th
+83fc094f-4a16d450.th
+1ef250f1-592467ce.th
+902315c2-b39ce9c9.th
+9a6b4851-03af0aa6.th
+fa0cb7f9-100d8bf4.th
+# Hybrid Transformer models
+root: hybrid_transformer/
+955717e8-8726e21a.th
+f7e0c4bc-ba3fe64a.th
+d12395a8-e57c48e6.th
+92cfc3b6-ef3bcb9c.th
+04573f0d-f3cf25b2.th
+75fc33f5-1941ce65.th
+# Experimental 6 sources model
+5c90dfd2-34c22ccb.th
diff --git a/AIMeiSheng/demucs/remote/hdemucs_mmi.yaml b/AIMeiSheng/demucs/remote/hdemucs_mmi.yaml
new file mode 100644
index 0000000..0ea0891
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/hdemucs_mmi.yaml
@@ -0,0 +1,2 @@
+models: ['75fc33f5']
+segment: 44
diff --git a/AIMeiSheng/demucs/remote/htdemucs.yaml b/AIMeiSheng/demucs/remote/htdemucs.yaml
new file mode 100644
index 0000000..0d5f208
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/htdemucs.yaml
@@ -0,0 +1 @@
+models: ['955717e8']
diff --git a/AIMeiSheng/demucs/remote/htdemucs_6s.yaml b/AIMeiSheng/demucs/remote/htdemucs_6s.yaml
new file mode 100644
index 0000000..651a0fa
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/htdemucs_6s.yaml
@@ -0,0 +1 @@
+models: ['5c90dfd2']
diff --git a/AIMeiSheng/demucs/remote/htdemucs_ft.yaml b/AIMeiSheng/demucs/remote/htdemucs_ft.yaml
new file mode 100644
index 0000000..ba5c69c
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/htdemucs_ft.yaml
@@ -0,0 +1,7 @@
+models: ['f7e0c4bc', 'd12395a8', '92cfc3b6', '04573f0d']
+weights: [
+ [1., 0., 0., 0.],
+ [0., 1., 0., 0.],
+ [0., 0., 1., 0.],
+ [0., 0., 0., 1.],
+]
\ No newline at end of file
diff --git a/AIMeiSheng/demucs/remote/mdx.yaml b/AIMeiSheng/demucs/remote/mdx.yaml
new file mode 100644
index 0000000..4e81a50
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/mdx.yaml
@@ -0,0 +1,8 @@
+models: ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68']
+weights: [
+ [1., 1., 0., 0.],
+ [0., 1., 0., 0.],
+ [1., 0., 1., 1.],
+ [1., 0., 1., 1.],
+]
+segment: 44
diff --git a/AIMeiSheng/demucs/remote/mdx_extra.yaml b/AIMeiSheng/demucs/remote/mdx_extra.yaml
new file mode 100644
index 0000000..847bf66
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/mdx_extra.yaml
@@ -0,0 +1,2 @@
+models: ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08']
+segment: 44
\ No newline at end of file
diff --git a/AIMeiSheng/demucs/remote/mdx_extra_q.yaml b/AIMeiSheng/demucs/remote/mdx_extra_q.yaml
new file mode 100644
index 0000000..87702bc
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/mdx_extra_q.yaml
@@ -0,0 +1,2 @@
+models: ['83fc094f', '464b36d7', '14fc6a69', '7fd6ef75']
+segment: 44
diff --git a/AIMeiSheng/demucs/remote/mdx_q.yaml b/AIMeiSheng/demucs/remote/mdx_q.yaml
new file mode 100644
index 0000000..827d2c6
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/mdx_q.yaml
@@ -0,0 +1,8 @@
+models: ['6b9c2ca1', 'b72baf4e', '42e558d4', '305bc58f']
+weights: [
+ [1., 1., 0., 0.],
+ [0., 1., 0., 0.],
+ [1., 0., 1., 1.],
+ [1., 0., 1., 1.],
+]
+segment: 44
diff --git a/AIMeiSheng/demucs/remote/repro_mdx_a.yaml b/AIMeiSheng/demucs/remote/repro_mdx_a.yaml
new file mode 100644
index 0000000..691abc2
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/repro_mdx_a.yaml
@@ -0,0 +1,2 @@
+models: ['9a6b4851', '1ef250f1', 'fa0cb7f9', '902315c2']
+segment: 44
diff --git a/AIMeiSheng/demucs/remote/repro_mdx_a_hybrid_only.yaml b/AIMeiSheng/demucs/remote/repro_mdx_a_hybrid_only.yaml
new file mode 100644
index 0000000..78eb8e0
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/repro_mdx_a_hybrid_only.yaml
@@ -0,0 +1,2 @@
+models: ['fa0cb7f9', '902315c2', 'fa0cb7f9', '902315c2']
+segment: 44
diff --git a/AIMeiSheng/demucs/remote/repro_mdx_a_time_only.yaml b/AIMeiSheng/demucs/remote/repro_mdx_a_time_only.yaml
new file mode 100644
index 0000000..d5d16ea
--- /dev/null
+++ b/AIMeiSheng/demucs/remote/repro_mdx_a_time_only.yaml
@@ -0,0 +1,2 @@
+models: ['9a6b4851', '9a6b4851', '1ef250f1', '1ef250f1']
+segment: 44
diff --git a/AIMeiSheng/demucs/repitch.py b/AIMeiSheng/demucs/repitch.py
new file mode 100644
index 0000000..ebef736
--- /dev/null
+++ b/AIMeiSheng/demucs/repitch.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Utility for on the fly pitch/tempo change for data augmentation."""
+
+import random
+import subprocess as sp
+import tempfile
+
+import torch
+import torchaudio as ta
+
+from .audio import save_audio
+
+
class RepitchedWrapper:
    """
    Wrap a dataset to apply online change of pitch / tempo.

    With probability `proba`, the streams of an item are passed through `repitch`;
    otherwise they are left untouched. In both cases every item is cropped to
    ``int((1 - 0.01 * max_tempo) * length)`` samples so the output length does
    not depend on the random tempo change.
    """
    def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12,
                 tempo_std=5, vocals=None, same=True):
        """
        Args:
            dataset: indexable dataset; items are presumably tensors shaped
                (streams, channels, time) — TODO confirm against callers.
            proba: probability of augmenting a given item.
            max_pitch: maximum absolute pitch shift, in semitones.
            max_tempo: maximum absolute tempo change, in percent.
            tempo_std: standard deviation of the Gaussian tempo change, in percent.
            vocals: indices of the streams repitched with ``voice=True``
                (soundstretch ``-speech``). Defaults to ``[3]``.
            same: if True, the pitch/tempo drawn for stream 0 is reused for all streams.
        """
        self.dataset = dataset
        self.proba = proba
        self.max_pitch = max_pitch
        self.max_tempo = max_tempo
        self.tempo_std = tempo_std
        self.same = same
        # Fresh list per instance: a list default argument would be shared by
        # every wrapper, so mutating one wrapper's `vocals` would affect all.
        self.vocals = [3] if vocals is None else vocals

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        streams = self.dataset[index]
        in_length = streams.shape[-1]
        # Shortest length any tempo change within +/- max_tempo can produce.
        out_length = int((1 - 0.01 * self.max_tempo) * in_length)

        if random.random() < self.proba:
            outs = []
            for idx, stream in enumerate(streams):
                # When `same` is True, parameters are drawn only for stream 0.
                if idx == 0 or not self.same:
                    delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
                    delta_tempo = random.gauss(0, self.tempo_std)
                    delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
                stream = repitch(
                    stream,
                    delta_pitch,
                    delta_tempo,
                    voice=idx in self.vocals)
                outs.append(stream[:, :out_length])
            streams = torch.stack(outs)
        else:
            streams = streams[..., :out_length]
        return streams
+
+
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
    """
    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semi tones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html

    Args:
        wav: audio to process, written to a temporary WAV with `save_audio`.
        pitch: pitch shift in semitones.
        tempo: tempo change in percent.
        voice: pass ``-speech`` to soundstretch.
        quick: pass ``-quick`` to soundstretch.
        samplerate: sample rate used to write the input and expected on output.

    Returns:
        the processed audio as loaded by `torchaudio.load`.

    Raises:
        RuntimeError: if soundstretch exits with a non-zero status.
    """
    # Context managers guarantee both temp files are closed and deleted even
    # when soundstretch or loading fails.
    with tempfile.NamedTemporaryFile(suffix=".wav") as infile:
        with tempfile.NamedTemporaryFile(suffix=".wav") as outfile:
            save_audio(wav, infile.name, samplerate, clip='clamp')
            command = [
                "soundstretch",
                infile.name,
                outfile.name,
                f"-pitch={pitch}",
                f"-tempo={tempo:.6f}",
            ]
            if quick:
                command += ["-quick"]
            if voice:
                command += ["-speech"]
            try:
                sp.run(command, capture_output=True, check=True)
            except sp.CalledProcessError as error:
                # errors='replace' so an undecodable stderr cannot mask the real failure.
                stderr = error.stderr.decode('utf-8', errors='replace')
                raise RuntimeError(f"Could not change bpm because {stderr}") from error
            wav, sr = ta.load(outfile.name)
    assert sr == samplerate
    return wav
diff --git a/AIMeiSheng/demucs/repo.py b/AIMeiSheng/demucs/repo.py
new file mode 100644
index 0000000..75f2afa
--- /dev/null
+++ b/AIMeiSheng/demucs/repo.py
@@ -0,0 +1,180 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Represents a model repository, including pre-trained models and bags of models.
+A repo can either be the main remote repository stored in AWS, or a local repository
+with your own models.
+"""
+
+from hashlib import sha256
+from pathlib import Path
+import typing as tp
+
+import torch
+import yaml
+import os
+
+from .apply import BagOfModels, Model
+from .states import load_model
+from AIMeiSheng.docker_demo.common import gs_demucs_model_path
+
# Type alias for whatever a repository can return: a single model or a bag of models.
AnyModel = tp.Union[Model, BagOfModels]
+
+
class ModelLoadingError(RuntimeError):
    """Raised when a model cannot be found, fails checksum verification, or is ambiguous."""
+
+
def check_checksum(path: Path, checksum: str):
    """Verify that the SHA-256 hex digest of ``path`` starts with ``checksum``.

    The file is hashed in 1 MiB chunks. Only the first ``len(checksum)`` hex
    characters of the digest are compared.

    Raises:
        ModelLoadingError: if that digest prefix differs from ``checksum``.
    """
    digest = sha256()
    with open(path, 'rb') as handle:
        for chunk in iter(lambda: handle.read(2**20), b''):
            digest.update(chunk)
    prefix = digest.hexdigest()[:len(checksum)]
    if prefix == checksum:
        return
    raise ModelLoadingError(f'Invalid checksum for file {path}, '
                            f'expected {checksum} but got {prefix}')
+
+
class ModelOnlyRepo:
    """Abstract interface for repositories of individual models (no bags).

    Concrete repos override every method. The base versions only raise
    ``NotImplementedError``.
    """

    def has_model(self, sig: str) -> bool:
        """Tell whether a model with signature ``sig`` is available."""
        raise NotImplementedError()

    def get_model(self, sig: str) -> Model:
        """Load and return the model with signature ``sig``."""
        raise NotImplementedError()

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        """Return the mapping from signature to model location."""
        raise NotImplementedError()
+
+
class RemoteRepo(ModelOnlyRepo):
    """Repository of individual models indexed by signature.

    ``models`` maps each signature to a URL. In this fork the URL is not
    downloaded. Only its last path component is used, and the checkpoint is read
    from the local folder ``gs_demucs_model_path``, which presumably holds copies
    of the remote files.
    """

    def __init__(self, models: tp.Dict[str, str]):
        self._models = models

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        """Load model ``sig`` from the local copy of its remote checkpoint.

        Raises:
            ModelLoadingError: if ``sig`` is not a known signature.
        """
        try:
            url = self._models[sig]
        except KeyError:
            raise ModelLoadingError(
                f'Could not find a pre-trained model with signature {sig}.') from None
        # Only the file name of the URL matters: the file is looked up locally.
        path_url_model = os.path.join(gs_demucs_model_path, url.split('/')[-1])
        # Load onto the CPU, as the original download path did, so checkpoints
        # saved with CUDA tensors also load on CPU-only machines.
        pkg = torch.load(path_url_model, map_location='cpu')
        return load_model(pkg)

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        return self._models  # type: ignore
+
+
class LocalRepo(ModelOnlyRepo):
    """Repository of individual models stored as ``.th`` files in the folder ``root``.

    ``<sig>.th`` provides model ``sig``. ``<sig>-<checksum>.th`` also records a
    SHA-256 prefix that is verified before loading.
    """

    def __init__(self, root: Path):
        self.root = root
        self.scan()

    def scan(self):
        """Rebuild the signature -> file and signature -> checksum indexes from ``root``.

        Raises:
            ModelLoadingError: if two files resolve to the same signature.
        """
        self._models = {}
        self._checksums = {}
        for file in self.root.iterdir():
            if file.suffix != '.th':
                continue
            if '-' in file.stem:
                # Split on the last dash only. The checksum is the final component,
                # and stems containing several dashes no longer crash the unpacking.
                xp_sig, checksum = file.stem.rsplit('-', 1)
                self._checksums[xp_sig] = checksum
            else:
                xp_sig = file.stem
            if xp_sig in self._models:
                raise ModelLoadingError(
                    f'Duplicate pre-trained model exist for signature {xp_sig}. '
                    'Please delete all but one.')
            self._models[xp_sig] = file

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        """Load model ``sig``, verifying its checksum first when one was recorded.

        Raises:
            ModelLoadingError: if ``sig`` is unknown or the checksum does not match.
        """
        try:
            file = self._models[sig]
        except KeyError:
            raise ModelLoadingError(
                f'Could not find pre-trained model with signature {sig}.') from None
        if sig in self._checksums:
            check_checksum(file, self._checksums[sig])
        return load_model(file)

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        return self._models
+
+
class BagOnlyRepo:
    """Handles only YAML files containing bag of models, leaving the actual
    model loading to some Repo.

    Each ``<name>.yaml`` in ``root`` describes bag ``name``. It has a ``models``
    list of signatures resolved through ``model_repo``, plus optional ``weights``
    and ``segment`` entries.
    """

    def __init__(self, root: Path, model_repo: ModelOnlyRepo):
        self.root = root
        self.model_repo = model_repo
        self.scan()

    def scan(self):
        """Rebuild the bag name -> YAML file index from ``root``."""
        self._bags = {}
        for file in self.root.iterdir():
            if file.suffix == '.yaml':
                self._bags[file.stem] = file

    def has_model(self, name: str) -> bool:
        return name in self._bags

    def get_model(self, name: str) -> BagOfModels:
        """Load every model of bag ``name`` and wrap them in a ``BagOfModels``.

        Raises:
            ModelLoadingError: if ``name`` is not a known bag.
        """
        try:
            yaml_file = self._bags[name]
        except KeyError:
            raise ModelLoadingError(f'{name} is neither a single pre-trained model or '
                                    'a bag of models.') from None
        # Close the YAML file deterministically instead of leaking the handle.
        with open(yaml_file) as file:
            bag = yaml.safe_load(file)
        signatures = bag['models']
        models = [self.model_repo.get_model(sig) for sig in signatures]
        weights = bag.get('weights')
        segment = bag.get('segment')
        return BagOfModels(models, weights, segment)

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        return self._bags
+
+
class AnyModelRepo:
    """Front-end combining a single-model repository with a bag repository.

    Lookups try ``model_repo`` first and fall back to ``bag_repo``.
    """

    def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
        self.model_repo = model_repo
        self.bag_repo = bag_repo

    def has_model(self, name_or_sig: str) -> bool:
        return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)

    def get_model(self, name_or_sig: str) -> AnyModel:
        if self.model_repo.has_model(name_or_sig):
            return self.model_repo.get_model(name_or_sig)
        return self.bag_repo.get_model(name_or_sig)

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        """Return single models and bags together; a bag overrides a model of the same name.

        A new dict is built because the sub-repositories return their internal
        indexes. Writing bag entries into them would make ``model_repo`` report
        bags as single models.
        """
        models = dict(self.model_repo.list_model())
        models.update(self.bag_repo.list_model())
        return models
diff --git a/AIMeiSheng/demucs/separate.py b/AIMeiSheng/demucs/separate.py
new file mode 100644
index 0000000..d5102ed
--- /dev/null
+++ b/AIMeiSheng/demucs/separate.py
@@ -0,0 +1,222 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import sys
+from pathlib import Path
+
+from dora.log import fatal
+import torch as th
+
+from .api import Separator, save_audio, list_models
+
+from .apply import BagOfModels
+from .htdemucs import HTDemucs
+from .pretrained import add_model_flags, ModelLoadingError
+
+
def get_parser():
    """Build the command-line parser for ``demucs.separate``.

    Model-selection flags are added by ``add_model_flags``. All other options
    (output location and naming, device, splitting, stem selection, output
    format) are defined here.
    """
    parser = argparse.ArgumentParser("demucs.separate",
                                     description="Separate the sources for the given tracks")
    parser.add_argument("tracks", nargs='*', type=Path, default=[], help='Path to tracks')
    add_model_flags(parser)
    parser.add_argument("--list-models", action="store_true", help="List available models "
                        "from current repo and exit")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("-o",
                        "--out",
                        type=Path,
                        default=Path("separated"),
                        help="Folder where to put extracted tracks. A subfolder "
                        "with the model name will be created.")
    parser.add_argument("--filename",
                        default="{track}/{stem}.{ext}",
                        help="Set the name of output file. \n"
                        'Use "{track}", "{trackext}", "{stem}", "{ext}" to use '
                        "variables of track name without extension, track extension, "
                        "stem name and default output file extension. \n"
                        'Default is "{track}/{stem}.{ext}".')
    parser.add_argument("-d",
                        "--device",
                        default="cuda" if th.cuda.is_available() else "cpu",
                        help="Device to use, default is cuda if available else cpu")
    # The two help fragments used to be glued together without a space
    # ("stabilization.Increase").
    parser.add_argument("--shifts",
                        default=1,
                        type=int,
                        help="Number of random shifts for equivariant stabilization. "
                        "Increases separation time but improves quality for Demucs. 10 was used "
                        "in the original paper.")
    parser.add_argument("--overlap",
                        default=0.25,
                        type=float,
                        help="Overlap between the splits.")
    split_group = parser.add_mutually_exclusive_group()
    split_group.add_argument("--no-split",
                             action="store_false",
                             dest="split",
                             default=True,
                             help="Doesn't split audio in chunks. "
                             "This can use large amounts of memory.")
    split_group.add_argument("--segment", type=int,
                             help="Set split size of each chunk. "
                             "This can help save memory of graphic card. ")
    parser.add_argument("--two-stems",
                        dest="stem", metavar="STEM",
                        help="Only separate audio into {STEM} and no_{STEM}. ")
    parser.add_argument("--other-method", dest="other_method", choices=["none", "add", "minus"],
                        default="add", help='Decide how to get "no_{STEM}". "none" will not save '
                        '"no_{STEM}". "add" will add all the other stems. "minus" will use the '
                        "original track minus the selected stem.")
    depth_group = parser.add_mutually_exclusive_group()
    depth_group.add_argument("--int24", action="store_true",
                             help="Save wav output as 24 bits wav.")
    depth_group.add_argument("--float32", action="store_true",
                             help="Save wav output as float32 (2x bigger).")
    parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp", "none"],
                        help="Strategy for avoiding clipping: rescaling entire signal "
                        "if necessary (rescale) or hard clipping (clamp).")
    format_group = parser.add_mutually_exclusive_group()
    format_group.add_argument("--flac", action="store_true",
                              help="Convert the output wavs to flac.")
    format_group.add_argument("--mp3", action="store_true",
                              help="Convert the output wavs to mp3.")
    parser.add_argument("--mp3-bitrate",
                        default=320,
                        type=int,
                        help="Bitrate of converted mp3.")
    parser.add_argument("--mp3-preset", choices=range(2, 8), type=int, default=2,
                        help="Encoder preset of MP3, 2 for highest quality, 7 for "
                        "fastest speed. Default is 2")
    parser.add_argument("-j", "--jobs",
                        default=0,
                        type=int,
                        help="Number of jobs. This can increase memory usage but will "
                        "be much faster when multiple cores are available.")

    return parser
+
+
def _stem_path(out, filename, track, stem, ext):
    """Return the output path of ``stem`` for ``track`` and create its parent folder.

    ``filename`` is the ``--filename`` template. ``{track}`` is the track name
    without its last extension, and ``{trackext}`` is that extension (the whole
    name when there is no dot).
    """
    path = out / filename.format(
        track=track.name.rsplit(".", 1)[0],
        trackext=track.name.rsplit(".", 1)[-1],
        stem=stem,
        ext=ext,
    )
    path.parent.mkdir(parents=True, exist_ok=True)
    return path


def main(opts=None):
    """Command-line entry point: separate every track given on the command line.

    Args:
        opts: optional list of argument strings passed to ``parse_args``
            (``None`` lets argparse read ``sys.argv``).

    Exits with status 0 after ``--list-models`` and with status 1 when no track
    is given. Model-loading and option-validation errors go through ``fatal``.
    """
    parser = get_parser()
    args = parser.parse_args(opts)
    if args.list_models:
        models = list_models(args.repo)
        print("Bag of models:", end="\n ")
        print("\n ".join(models["bag"]))
        print("Single models:", end="\n ")
        print("\n ".join(models["single"]))
        sys.exit(0)
    if len(args.tracks) == 0:
        print("error: the following arguments are required: tracks", file=sys.stderr)
        sys.exit(1)

    try:
        separator = Separator(model=args.name,
                              repo=args.repo,
                              device=args.device,
                              shifts=args.shifts,
                              split=args.split,
                              overlap=args.overlap,
                              progress=True,
                              jobs=args.jobs,
                              segment=args.segment)
    except ModelLoadingError as error:
        fatal(error.args[0])

    max_allowed_segment = float('inf')
    if isinstance(separator.model, HTDemucs):
        max_allowed_segment = float(separator.model.segment)
    elif isinstance(separator.model, BagOfModels):
        max_allowed_segment = separator.model.max_allowed_segment
    if args.segment is not None and args.segment > max_allowed_segment:
        fatal("Cannot use a Transformer model with a longer segment "
              f"than it was trained for. Maximum segment is: {max_allowed_segment}")

    if isinstance(separator.model, BagOfModels):
        print(
            f"Selected model is a bag of {len(separator.model.models)} models. "
            "You will see that many progress bars per track."
        )

    if args.stem is not None and args.stem not in separator.model.sources:
        fatal(
            'error: stem "{stem}" is not in selected model. '
            "STEM must be one of {sources}.".format(
                stem=args.stem, sources=", ".join(separator.model.sources)
            )
        )
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    print(f"Separated tracks will be stored in {out.resolve()}")

    # The output format and save options are identical for every track: compute once.
    if args.mp3:
        ext = "mp3"
    elif args.flac:
        ext = "flac"
    else:
        ext = "wav"
    kwargs = {
        "samplerate": separator.samplerate,
        "bitrate": args.mp3_bitrate,
        "preset": args.mp3_preset,
        "clip": args.clip_mode,
        "as_float": args.float32,
        "bits_per_sample": 24 if args.int24 else 16,
    }

    for track in args.tracks:
        if not track.exists():
            print(f"File {track} does not exist. If the path contains spaces, "
                  'please try again after surrounding the entire path with quotes "".',
                  file=sys.stderr)
            continue
        print(f"Separating track {track}")

        origin, res = separator.separate_audio_file(track)

        if args.stem is None:
            for name, source in res.items():
                save_audio(source, str(_stem_path(out, args.filename, track, name, ext)),
                           **kwargs)
            continue

        if args.other_method == "minus":
            minus_path = _stem_path(out, args.filename, track, "minus_" + args.stem, ext)
            save_audio(origin - res[args.stem], str(minus_path), **kwargs)
        # Warning: after popping the stem, the selected stem is no longer in the dict 'res'.
        selected = res.pop(args.stem)
        save_audio(selected, str(_stem_path(out, args.filename, track, args.stem, ext)),
                   **kwargs)
        if args.other_method == "add":
            # Start from zeros shaped like the selected stem, so a model whose only
            # source is the selected stem yields a silent "no_" stem instead of
            # raising StopIteration on the now-empty dict.
            other_stem = th.zeros_like(selected)
            for source in res.values():
                other_stem += source
            no_path = _stem_path(out, args.filename, track, "no_" + args.stem, ext)
            save_audio(other_stem, str(no_path), **kwargs)
+
+
# Running this module as a script starts the command-line separation.
if __name__ == "__main__":
    main()
diff --git a/AIMeiSheng/demucs/solver.py b/AIMeiSheng/demucs/solver.py
new file mode 100644
index 0000000..7c80b14
--- /dev/null
+++ b/AIMeiSheng/demucs/solver.py
@@ -0,0 +1,405 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Main training loop."""
+
+import logging
+
+from dora import get_xp
+from dora.utils import write_and_rename
+from dora.log import LogProgress, bold
+import torch
+import torch.nn.functional as F
+
+from . import augment, distrib, states, pretrained
+from .apply import apply_model
+from .ema import ModelEMA
+from .evaluate import evaluate, new_sdr
+from .svd import svd_penalty
+from .utils import pull_metric, EMA
+
+logger = logging.getLogger(__name__)
+
+
+def _summary(metrics):
+ return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items())
+
+
class Solver(object):
    """Main training driver: epochs, validation with EMA selection, testing and checkpoints.

    Args:
        loaders: mapping with at least ``'train'`` and ``'valid'`` data loaders.
        model: the model to train.
        optimizer: optimizer over ``model``'s parameters.
        args: experiment configuration, read through attribute access
            (e.g. ``args.optim.loss``, ``args.ema.batch``).
    """

    def __init__(self, loaders, model, optimizer, args):
        self.args = args
        self.loaders = loaders

        self.model = model
        self.optimizer = optimizer
        self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer)
        self.dmodel = distrib.wrap(model)
        self.device = next(iter(self.model.parameters())).device

        # Exponential moving average of the model, either updated every batch or epoch.
        # The best model from all the EMAs and the original one is kept based on the valid
        # loss for the final best model.
        self.emas = {'batch': [], 'epoch': []}
        for kind in self.emas.keys():
            decays = getattr(args.ema, kind)
            # Batch EMAs live on the training device; epoch EMAs are kept on the CPU.
            device = self.device if kind == 'batch' else 'cpu'
            if decays:
                for decay in decays:
                    self.emas[kind].append(ModelEMA(self.model, decay, device=device))

        # data augment
        augments = [augment.Shift(shift=int(args.dset.samplerate * args.dset.shift),
                                  same=args.augment.shift_same)]
        if args.augment.flip:
            augments += [augment.FlipChannels(), augment.FlipSign()]
        for aug in ['scale', 'remix']:
            kw = getattr(args.augment, aug)
            if kw.proba:
                augments.append(getattr(augment, aug.capitalize())(**kw))
        self.augment = torch.nn.Sequential(*augments)

        xp = get_xp()
        self.folder = xp.folder
        # Checkpoints
        self.checkpoint_file = xp.folder / 'checkpoint.th'
        self.best_file = xp.folder / 'best.th'
        logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve())
        self.best_state = None
        self.best_changed = False

        self.link = xp.link
        self.history = self.link.history

        self._reset()

    def _serialize(self, epoch):
        """Write ``checkpoint.th`` (model, optimizer, history, best state, EMA states).

        Every ``save_every`` epochs (except the last) a numbered copy is also kept,
        and ``best.th`` is rewritten when the best state changed since the last save.
        """
        package = {}
        package['state'] = self.model.state_dict()
        package['optimizer'] = self.optimizer.state_dict()
        package['history'] = self.history
        package['best_state'] = self.best_state
        package['args'] = self.args
        for kind, emas in self.emas.items():
            for k, ema in enumerate(emas):
                package[f'ema_{kind}_{k}'] = ema.state_dict()
        with write_and_rename(self.checkpoint_file) as tmp:
            torch.save(package, tmp)

        save_every = self.args.save_every
        if save_every and (epoch + 1) % save_every == 0 and epoch + 1 != self.args.epochs:
            with write_and_rename(self.folder / f'checkpoint_{epoch + 1}.th') as tmp:
                torch.save(package, tmp)

        if self.best_changed:
            # Saving only the latest best model.
            with write_and_rename(self.best_file) as tmp:
                package = states.serialize_model(self.model, self.args)
                package['state'] = self.best_state
                torch.save(package, tmp)
            self.best_changed = False

    def _reset(self):
        """Reset state of the solver, potentially using checkpoint.

        Priority: this XP's own ``checkpoint.th``, then ``continue_pretrained``
        (a pretrained model's weights), then ``continue_from`` (another XP's checkpoint).
        """
        if self.checkpoint_file.exists():
            logger.info(f'Loading checkpoint model: {self.checkpoint_file}')
            package = torch.load(self.checkpoint_file, 'cpu')
            self.model.load_state_dict(package['state'])
            self.optimizer.load_state_dict(package['optimizer'])
            self.history[:] = package['history']
            self.best_state = package['best_state']
            for kind, emas in self.emas.items():
                for k, ema in enumerate(emas):
                    ema.load_state_dict(package[f'ema_{kind}_{k}'])
        elif self.args.continue_pretrained:
            model = pretrained.get_model(
                name=self.args.continue_pretrained,
                repo=self.args.pretrained_repo)
            self.model.load_state_dict(model.state_dict())
        elif self.args.continue_from:
            name = 'checkpoint.th'
            root = self.folder.parent
            cf = root / str(self.args.continue_from) / name
            logger.info("Loading from %s", cf)
            package = torch.load(cf, 'cpu')
            self.best_state = package['best_state']
            if self.args.continue_best:
                self.model.load_state_dict(package['best_state'], strict=False)
            else:
                self.model.load_state_dict(package['state'], strict=False)
            if self.args.continue_opt:
                self.optimizer.load_state_dict(package['optimizer'])

    def _format_train(self, metrics: dict) -> dict:
        """Formatting for train/valid metrics."""
        losses = {
            'loss': format(metrics['loss'], ".4f"),
            'reco': format(metrics['reco'], ".4f"),
        }
        if 'nsdr' in metrics:
            losses['nsdr'] = format(metrics['nsdr'], ".3f")
        if self.quantizer is not None:
            losses['ms'] = format(metrics['ms'], ".2f")
        if 'grad' in metrics:
            losses['grad'] = format(metrics['grad'], ".4f")
        if 'best' in metrics:
            losses['best'] = format(metrics['best'], '.4f')
        if 'bname' in metrics:
            losses['bname'] = metrics['bname']
        if 'penalty' in metrics:
            losses['penalty'] = format(metrics['penalty'], ".4f")
        if 'hloss' in metrics:
            losses['hloss'] = format(metrics['hloss'], ".4f")
        return losses

    def _format_test(self, metrics: dict) -> dict:
        """Formatting for test metrics."""
        losses = {}
        if 'sdr' in metrics:
            losses['sdr'] = format(metrics['sdr'], '.3f')
        if 'nsdr' in metrics:
            losses['nsdr'] = format(metrics['nsdr'], '.3f')
        for source in self.model.sources:
            key = f'sdr_{source}'
            if key in metrics:
                losses[key] = format(metrics[key], '.3f')
            key = f'nsdr_{source}'
            if key in metrics:
                losses[key] = format(metrics[key], '.3f')
        return losses

    def train(self):
        """Run (or resume) training up to ``args.epochs`` epochs.

        Each epoch trains, validates the model and every EMA, and keeps the best
        candidate. Periodically (and on the last epoch) it evaluates on the test
        set. Metrics are pushed to the XP link, and rank 0 writes checkpoints.
        """
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
        for epoch, metrics in enumerate(self.history):
            formatted = self._format_train(metrics['train'])
            logger.info(
                bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
            formatted = self._format_train(metrics['valid'])
            logger.info(
                bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
            if 'test' in metrics:
                formatted = self._format_test(metrics['test'])
                if formatted:
                    logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))

        epoch = 0
        for epoch in range(len(self.history), self.args.epochs):
            # Train one epoch
            self.model.train()  # Turn on BatchNorm & Dropout
            metrics = {}
            logger.info('-' * 70)
            logger.info("Training...")
            metrics['train'] = self._run_one_epoch(epoch)
            formatted = self._format_train(metrics['train'])
            logger.info(
                bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))

            # Cross validation
            logger.info('-' * 70)
            logger.info('Cross validation...')
            self.model.eval()  # Turn off Batchnorm & Dropout
            with torch.no_grad():
                valid = self._run_one_epoch(epoch, train=False)
                bvalid = valid
                bname = 'main'
                state = states.copy_state(self.model.state_dict())
                metrics['valid'] = {}
                metrics['valid']['main'] = valid
                key = self.args.test.metric
                for kind, emas in self.emas.items():
                    for k, ema in enumerate(emas):
                        with ema.swap():
                            valid = self._run_one_epoch(epoch, train=False)
                        name = f'ema_{kind}_{k}'
                        metrics['valid'][name] = valid
                        a = valid[key]
                        b = bvalid[key]
                        # nsdr is "higher is better": negate so `a < b` means "a is better".
                        if key.startswith('nsdr'):
                            a = -a
                            b = -b
                        if a < b:
                            bvalid = valid
                            state = ema.state
                            bname = name
                    metrics['valid'].update(bvalid)
                    metrics['valid']['bname'] = bname

            valid_loss = metrics['valid'][key]
            mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss]
            if key.startswith('nsdr'):
                best_loss = max(mets)
            else:
                best_loss = min(mets)
            metrics['valid']['best'] = best_loss
            if self.args.svd.penalty > 0:
                kw = dict(self.args.svd)
                kw.pop('penalty')
                with torch.no_grad():
                    penalty = svd_penalty(self.model, exact=True, **kw)
                metrics['valid']['penalty'] = penalty

            formatted = self._format_train(metrics['valid'])
            logger.info(
                bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))

            # Save the best model
            if valid_loss == best_loss or self.args.dset.train_valid:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
                self.best_state = states.copy_state(state)
                self.best_changed = True

            # Eval model every `test.every` epoch or on last epoch
            should_eval = (epoch + 1) % self.args.test.every == 0
            is_last = epoch == self.args.epochs - 1
            # # Tries to detect divergence in a reliable way and finish job
            # # not to waste compute.
            # # Commented out as this was super specific to the MDX competition.
            # reco = metrics['valid']['main']['reco']
            # div = epoch >= 180 and reco > 0.18
            # div = div or epoch >= 100 and reco > 0.25
            # div = div and self.args.optim.loss == 'l1'
            # if div:
            #     logger.warning("Finishing training early because valid loss is too high.")
            #     is_last = True
            if should_eval or is_last:
                # Evaluate on the testset
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # We switch to the best known model for testing
                if self.args.test.best:
                    state = self.best_state
                else:
                    state = states.copy_state(self.model.state_dict())
                compute_sdr = self.args.test.sdr and is_last
                with states.swap_state(self.model, state):
                    with torch.no_grad():
                        metrics['test'] = evaluate(self, compute_sdr=compute_sdr)
                formatted = self._format_test(metrics['test'])
                logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))
            self.link.push_metrics(metrics)

            if distrib.rank == 0:
                # Save model each epoch
                self._serialize(epoch)
                logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())
            if is_last:
                break

    def _run_one_epoch(self, epoch, train=True):
        """Run one pass over the train (``train=True``) or valid loader.

        In training mode, batches are augmented and the mixture is the sum of
        the sources; each step backpropagates, optionally clips gradients and
        updates the batch EMAs. In validation mode, the first channel group of
        each batch is the mixture and nsdr metrics are computed.

        Returns:
            The running-averaged metrics, passed through ``distrib.average``.

        Raises:
            ValueError: if ``args.optim.loss`` is neither ``'l1'`` nor ``'mse'``.
        """
        args = self.args
        data_loader = self.loaders['train'] if train else self.loaders['valid']
        if distrib.world_size > 1 and train:
            data_loader.sampler.set_epoch(epoch)

        label = ["Valid", "Train"][train]
        name = label + f" | Epoch {epoch + 1}"
        total = len(data_loader)
        if args.max_batches:
            total = min(total, args.max_batches)
        logprog = LogProgress(logger, data_loader, total=total,
                              updates=self.args.misc.num_prints, name=name)
        averager = EMA()

        for idx, sources in enumerate(logprog):
            sources = sources.to(self.device)
            if train:
                sources = self.augment(sources)
                mix = sources.sum(dim=1)
            else:
                mix = sources[:, 0]
                sources = sources[:, 1:]

            if not train and self.args.valid_apply:
                estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0)
            else:
                estimate = self.dmodel(mix)
            if train and hasattr(self.model, 'transform_target'):
                sources = self.model.transform_target(mix, sources)
            assert estimate.shape == sources.shape, (estimate.shape, sources.shape)
            dims = tuple(range(2, sources.dim()))

            if args.optim.loss == 'l1':
                loss = F.l1_loss(estimate, sources, reduction='none')
                loss = loss.mean(dims).mean(0)
                reco = loss
            elif args.optim.loss == 'mse':
                loss = F.mse_loss(estimate, sources, reduction='none')
                loss = loss.mean(dims)
                reco = loss**0.5
                reco = reco.mean(0)
            else:
                # Report the option actually read above; `args.loss` does not exist
                # and used to raise AttributeError instead of this message.
                raise ValueError(f"Invalid loss {args.optim.loss}")
            weights = torch.tensor(args.weights).to(sources)
            loss = (loss * weights).sum() / weights.sum()

            ms = 0
            if self.quantizer is not None:
                ms = self.quantizer.model_size()
            if args.quant.diffq:
                loss += args.quant.diffq * ms

            losses = {}
            losses['reco'] = (reco * weights).sum() / weights.sum()
            losses['ms'] = ms

            if not train:
                nsdrs = new_sdr(sources, estimate.detach()).mean(0)
                nsdr_total = 0
                for source, nsdr, w in zip(self.model.sources, nsdrs, weights):
                    losses[f'nsdr_{source}'] = nsdr
                    nsdr_total += w * nsdr
                losses['nsdr'] = nsdr_total / weights.sum()

            if train and args.svd.penalty > 0:
                kw = dict(args.svd)
                kw.pop('penalty')
                penalty = svd_penalty(self.model, **kw)
                losses['penalty'] = penalty
                loss += args.svd.penalty * penalty

            losses['loss'] = loss

            for k, source in enumerate(self.model.sources):
                losses[f'reco_{source}'] = reco[k]

            # optimize model in training mode
            if train:
                loss.backward()
                grad_norm = 0
                for p in self.model.parameters():
                    if p.grad is not None:
                        grad_norm += p.grad.data.norm()**2
                losses['grad'] = grad_norm ** 0.5
                if args.optim.clip_grad:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        args.optim.clip_grad)

                if self.args.flag == 'uns':
                    for n, p in self.model.named_parameters():
                        if p.grad is None:
                            print('no grad', n)
                self.optimizer.step()
                self.optimizer.zero_grad()
                for ema in self.emas['batch']:
                    ema.update()
            losses = averager(losses)
            logs = self._format_train(losses)
            logprog.update(**logs)
            # Just in case, clear some memory
            del loss, estimate, reco, ms
            # Stop after exactly `max_batches` batches, matching `total` above and the
            # `idx + 1` count given to `distrib.average` (the old `== idx` test ran one extra).
            if args.max_batches and idx + 1 >= args.max_batches:
                break
            if self.args.debug and train:
                break
            if self.args.flag == 'debug':
                break
        if train:
            for ema in self.emas['epoch']:
                ema.update()
        return distrib.average(losses, idx + 1)
diff --git a/AIMeiSheng/demucs/spec.py b/AIMeiSheng/demucs/spec.py
new file mode 100644
index 0000000..2925045
--- /dev/null
+++ b/AIMeiSheng/demucs/spec.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Conveniance wrapper to perform STFT and iSTFT"""
+
+import torch as th
+
+
def spectro(x, n_fft=512, hop_length=None, pad=0):
    """Complex STFT of ``x`` along its last dimension.

    Leading dimensions are flattened for ``th.stft`` and restored on the result,
    which has shape ``(*leading, freqs, frames)``. The FFT size is
    ``n_fft * (1 + pad)`` with a Hann window of length ``n_fft``. ``hop_length``
    falls back to ``n_fft // 4`` when falsy. On MPS devices the input is moved to
    the CPU first, and the result is left there.
    """
    *leading, length = x.shape
    flat = x.reshape(-1, length)
    if flat.device.type == 'mps':
        flat = flat.cpu()
    hop = hop_length or n_fft // 4
    window = th.hann_window(n_fft).to(flat)
    spec = th.stft(flat, n_fft * (1 + pad), hop,
                   window=window, win_length=n_fft, normalized=True,
                   center=True, return_complex=True, pad_mode='reflect')
    freqs, frames = spec.shape[-2:]
    return spec.view(*leading, freqs, frames)
+
+
def ispectro(z, hop_length=None, length=None, pad=0):
    """Inverse STFT over the last two dimensions of ``z`` (the inverse of `spectro`).

    ``z`` has shape ``(*leading, freqs, frames)``. The FFT size is recovered as
    ``2 * freqs - 2`` and the Hann window length as that size divided by
    ``1 + pad``, mirroring `spectro`. When ``hop_length`` is falsy it defaults to
    a quarter of the window length, the same default `spectro` uses. ``length``
    is forwarded to ``th.istft`` to fix the output length. On MPS devices the
    input is moved to the CPU first, and the result is left there.

    Returns:
        Tensor of shape ``(*leading, samples)``.
    """
    *other, freqs, frames = z.shape
    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = n_fft // (1 + pad)
    if not hop_length:
        # Left as None, `th.istft` would use n_fft // 4, a quarter of the *padded*
        # FFT size. That disagrees with `spectro` whenever pad > 0 and breaks the
        # round trip.
        hop_length = win_length // 4
    if z.device.type == 'mps':
        z = z.cpu()
    x = th.istft(z,
                 n_fft,
                 hop_length,
                 window=th.hann_window(win_length).to(z.real),
                 win_length=win_length,
                 normalized=True,
                 length=length,
                 center=True)
    _, length = x.shape
    return x.view(*other, length)
diff --git a/AIMeiSheng/demucs/states.py b/AIMeiSheng/demucs/states.py
new file mode 100644
index 0000000..361bb41
--- /dev/null
+++ b/AIMeiSheng/demucs/states.py
@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Utilities to save and load models.
+"""
+from contextlib import contextmanager
+
+import functools
+import hashlib
+import inspect
+import io
+from pathlib import Path
+import warnings
+
+from omegaconf import OmegaConf
+from dora.log import fatal
+import torch
+
+
+def _check_diffq():
+ try:
+ import diffq # noqa
+ except ImportError:
+ fatal('Trying to use DiffQ, but diffq is not installed.\n'
+ 'On Windows run: python.exe -m pip install diffq \n'
+ 'On Linux/Mac, run: python3 -m pip install diffq')
+
+
def get_quantizer(model, args, optimizer=None):
    """Build the quantizer requested by the XP quantization ``args``, or return None.

    A truthy ``args.diffq`` selects a ``DiffQuantizer``, which is also attached
    to ``optimizer`` when one is given. Otherwise a truthy ``args.qat`` selects a
    ``UniformQuantizer`` using ``args.qat`` bits. Both read ``args.min_size``.
    """
    if args.diffq:
        _check_diffq()
        from diffq import DiffQuantizer
        diff_quantizer = DiffQuantizer(
            model, min_size=args.min_size, group_size=args.group_size)
        if optimizer is not None:
            diff_quantizer.setup_optimizer(optimizer)
        return diff_quantizer
    if args.qat:
        _check_diffq()
        from diffq import UniformQuantizer
        return UniformQuantizer(model, bits=args.qat, min_size=args.min_size)
    return None
+
+
def load_model(path_or_package, strict=False):
    """Load a model from the given serialized model, either given as a dict (already loaded)
    or a path to a file on disk.

    Args:
        path_or_package: a package dict, or a ``str``/``Path`` to a file that is
            read with ``torch.load`` onto the CPU (warnings silenced).
        strict: when False, keyword arguments not accepted by the model class's
            signature are dropped with a warning instead of making construction fail.

    The package must provide ``klass``, ``args``, ``kwargs`` and ``state``.
    The caller's package is never modified.

    Raises:
        ValueError: if ``path_or_package`` is neither a dict nor a path.
    """
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            path = path_or_package
            package = torch.load(path, 'cpu')
    else:
        raise ValueError(f"Invalid type for {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    # Work on a copy: dropping unknown parameters used to delete keys from the
    # caller's package in place.
    kwargs = dict(package["kwargs"])

    if not strict:
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn("Dropping inexistant parameter " + key)
                del kwargs[key]
    model = klass(*args, **kwargs)

    state = package["state"]

    set_state(model, state)
    return model
+
+
def get_state(model, quantizer, half=False):
    """Return the model's state, quantized when ``quantizer`` is given.

    Without a quantizer, every entry of ``model.state_dict()`` is copied to the
    CPU, converted to ``torch.half`` if ``half`` is True (halving the size of the
    saved state). With a quantizer, its quantized state is returned, flagged with
    ``'__quantized': True``.
    """
    if quantizer is not None:
        quantized = quantizer.get_quantized_state()
        quantized['__quantized'] = True
        return quantized
    target_dtype = torch.half if half else None
    return {name: tensor.data.to(device='cpu', dtype=target_dtype)
            for name, tensor in model.state_dict().items()}
+
+
def set_state(model, state, quantizer=None):
    """Load ``state`` into ``model`` and return ``state``.

    A plain state dict goes through ``model.load_state_dict``. A state flagged
    with ``'__quantized'`` is restored through ``quantizer`` when one is given,
    otherwise with ``diffq.restore_quantized_state`` (diffq must be installed).
    """
    if not state.get('__quantized'):
        model.load_state_dict(state)
    elif quantizer is None:
        _check_diffq()
        from diffq import restore_quantized_state
        restore_quantized_state(model, state)
    else:
        # NOTE(review): `get_state` only adds '__quantized' to the quantizer's own
        # state, so whether a 'quantized' key exists here depends on the quantizer.
        # Confirm this lookup cannot raise KeyError.
        quantizer.restore_quantized_state(model, state['quantized'])
    return state
+
+
def save_with_checksum(content, path):
    """Save the given value on disk, along with a sha256 hash.

    ``content`` is serialized with ``torch.save``. It is written next to ``path``
    under the name ``<stem>-<first 8 hex chars of sha256>.<suffix>``.
    Should be used with the output of either `serialize_model` or `get_state`.
    """
    buffer = io.BytesIO()
    torch.save(content, buffer)
    payload = buffer.getvalue()
    digest_prefix = hashlib.sha256(payload).hexdigest()[:8]
    target = path.parent / f"{path.stem}-{digest_prefix}{path.suffix}"
    target.write_bytes(payload)
+
+
def serialize_model(model, training_args, quantizer=None, half=True):
    """Package ``model`` for saving.

    The package holds the class, the captured constructor args/kwargs, the state
    from `get_state` (half precision by default), and ``training_args`` converted
    to a plain container with interpolations resolved. ``model`` must expose
    ``_init_args_kwargs``, as set by a ``capture_init``-decorated ``__init__``.
    """
    init_args, init_kwargs = model._init_args_kwargs
    return {
        'klass': model.__class__,
        'args': init_args,
        'kwargs': init_kwargs,
        'state': get_state(model, quantizer, half),
        'training_args': OmegaConf.to_container(training_args, resolve=True),
    }
+
+
def copy_state(state):
    """Return a new dict where every value of ``state`` is moved to the CPU and cloned."""
    copied = {}
    for key, value in state.items():
        copied[key] = value.cpu().clone()
    return copied
+
+
@contextmanager
def swap_state(model, state):
    """Temporarily load ``state`` into ``model``, restoring the previous weights on exit.

    Usage::

        # model is in old state
        with swap_state(model, new_state):
            # model in new state
        # model back to old state

    ``state`` is loaded with ``strict=False``. The saved CPU copy of the old
    weights is restored (strictly) in a ``finally`` clause, so even if the body
    raises.
    """
    previous = copy_state(model.state_dict())
    model.load_state_dict(state, strict=False)
    try:
        yield
    finally:
        model.load_state_dict(previous)
+
+
def capture_init(init):
    """Decorate an ``__init__`` so instances remember their constructor arguments.

    Before the wrapped ``__init__`` runs, ``(args, kwargs)`` is stored on the
    instance as ``_init_args_kwargs``. This is the attribute `serialize_model`
    reads back.
    """
    @functools.wraps(init)
    def recording_init(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return recording_init
diff --git a/AIMeiSheng/demucs/svd.py b/AIMeiSheng/demucs/svd.py
new file mode 100644
index 0000000..1cbaa82
--- /dev/null
+++ b/AIMeiSheng/demucs/svd.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Ways to make the model stronger."""
+import random
+import torch
+
+
def power_iteration(m, niters=1, bs=1):
    """This is the power method. batch size is used to try multiple starting point in parallel.

    Args:
        m: square 2D tensor.
        niters: number of iterations, must be at least 1.
        bs: number of random starting vectors iterated in parallel.

    Returns:
        The mean over the `bs` vectors of `||m @ b||` at the last iteration,
        with `b` normalized. For a symmetric matrix this approximates its
        largest absolute eigenvalue.

    Raises:
        ValueError: if `niters < 1` (there would be no estimate to return).
    """
    assert m.dim() == 2
    assert m.shape[0] == m.shape[1]
    if niters < 1:
        raise ValueError(f"niters must be at least 1, got {niters}")
    dim = m.shape[0]
    b = torch.randn(dim, bs, device=m.device, dtype=m.dtype)

    for _ in range(niters):
        n = m.mm(b)
        norm = n.norm(dim=0, keepdim=True)
        # The small epsilon avoids dividing by zero for a null image.
        b = n / (1e-10 + norm)

    return norm.mean()
+
+
# We need a shared RNG so that all the distributed workers skip the penalty together;
# otherwise we wouldn't get any speed-up. Every process seeds it identically, so as
# long as each one calls `svd_penalty` the same number of times, they draw the same values.
penalty_rng = random.Random(1234)
+
+
def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True,
                proba=1, conv_only=False, exact=False, bs=1):
    """
    Penalty on the largest singular value for a layer.
    Args:
        - model: model to penalize
        - min_size: minimum size in MB of a layer to penalize.
        - dim: projection dimension for the svd_lowrank. Higher is better but slower.
        - niters: number of iterations in the algorithm used by svd_lowrank.
        - powm: use power method instead of lowrank SVD, my own experience
            is that it is both slower and less stable.
        - convtr: when True, differentiate between Conv and Transposed Conv.
            this is kept for compatibility with older experiments.
        - proba: probability to apply the penalty.
        - conv_only: only apply to conv and conv transposed, not LSTM
            (might not be reliable for other models than Demucs).
        - exact: use exact SVD (slow but useful at validation).
        - bs: batch_size for power method.

    Returns:
        The sum, over the penalized parameters, of an estimate of their squared
        largest singular value, divided by `proba`. Returns `0.` when the shared
        `penalty_rng` decides to skip this call.
    """
    if penalty_rng.random() > proba:
        return 0.

    transposed_convs = (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)
    total = 0
    for module in model.modules():
        swap_io = convtr and isinstance(module, transposed_convs)
        for name, weight in module.named_parameters(recurse=False):
            # numel / 2**18 is the size in MB assuming 4-byte entries.
            if weight.numel() / 2**18 < min_size:
                continue
            if weight.dim() in (3, 4):
                if swap_io:
                    # Transposed convs store (in, out, ...): put output channels first.
                    weight = weight.transpose(0, 1).contiguous()
                weight = weight.view(len(weight), -1)
            elif weight.dim() == 1 or conv_only:
                continue
            assert weight.dim() == 2, (name, weight.shape)
            if exact:
                estimate = torch.svd(weight, compute_uv=False)[1].pow(2).max()
            elif powm:
                rows, cols = weight.shape
                # Power-iterate on the smaller Gram matrix; its top eigenvalue is sigma_max^2.
                gram = weight.mm(weight.t()) if rows < cols else weight.t().mm(weight)
                estimate = power_iteration(gram, niters, bs)
            else:
                estimate = torch.svd_lowrank(weight, dim, niters)[1][0].pow(2)
            total += estimate
    return total / proba
diff --git a/AIMeiSheng/demucs/train.py b/AIMeiSheng/demucs/train.py
new file mode 100644
index 0000000..9aa7b64
--- /dev/null
+++ b/AIMeiSheng/demucs/train.py
@@ -0,0 +1,251 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Main training script entry point"""
+
+import logging
+import os
+from pathlib import Path
+import sys
+
+from dora import hydra_main
+import hydra
+from hydra.core.global_hydra import GlobalHydra
+from omegaconf import OmegaConf
+import torch
+from torch import nn
+import torchaudio
+from torch.utils.data import ConcatDataset
+
+from . import distrib
+from .wav import get_wav_datasets, get_musdb_wav_datasets
+from .demucs import Demucs
+from .hdemucs import HDemucs
+from .htdemucs import HTDemucs
+from .repitch import RepitchedWrapper
+from .solver import Solver
+from .states import capture_init
+from .utils import random_subset
+
# Module-level logger; `main` switches it to DEBUG when `args.misc.verbose` is set.
logger = logging.getLogger(__name__)
+
+
class TorchHDemucsWrapper(nn.Module):
    """Wrapper around torchaudio HDemucs implementation to provide the proper metadata
    for model evaluation.
    See https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html"""

    @capture_init
    def __init__(self, **kwargs):
        """Build `torchaudio.models.HDemucs` from keyword arguments.

        `samplerate` and `segment` are removed from the kwargs and kept as
        attributes. `sources` is kept on `self.sources` and also forwarded.
        All remaining keywords go to the torchaudio constructor.
        """
        super().__init__()
        try:
            from torchaudio.models import HDemucs as TorchHDemucs
        except ImportError as err:
            raise ImportError(
                "Please upgrade torchaudio for using its implementation of HDemucs") from err
        self.samplerate = kwargs.pop('samplerate')
        self.segment = kwargs.pop('segment')
        self.sources = kwargs['sources']
        self.torch_hdemucs = TorchHDemucs(**kwargs)

    def forward(self, mix):
        # Call the submodule rather than `.forward` so that registered hooks run.
        return self.torch_hdemucs(mix)
+
+
def get_model(args):
    """Instantiate the separation model selected by `args.model`.

    Dataset-derived settings (sources, channel count, sample rate and segment
    length) are shared by all architectures. The segment length defaults to 4x
    the dataset segment when `args.model_segment` is falsy. Architecture
    specific options come from the config section named after the model.
    """
    model_classes = {
        'demucs': Demucs,
        'hdemucs': HDemucs,
        'htdemucs': HTDemucs,
        'torch_hdemucs': TorchHDemucsWrapper,
    }
    dset = args.dset
    shared_kwargs = dict(
        sources=list(dset.sources),
        audio_channels=dset.channels,
        samplerate=dset.samplerate,
        segment=args.model_segment or 4 * dset.segment,
    )
    model_cls = model_classes[args.model]
    specific_kwargs = OmegaConf.to_container(getattr(args, args.model), resolve=True)
    return model_cls(**shared_kwargs, **specific_kwargs)
+
+
def get_optimizer(model, args):
    """Create the optimizer configured in `args.optim` for `model`.

    Submodules that define `make_optim_group()` get their own parameter group
    (e.g. with a specific lr / weight decay). Every other parameter goes into a
    default group placed first. Supported optimizers: "adam" and "adamw".

    Raises:
        ValueError: if `args.optim.optim` names an unsupported optimizer.
    """
    seen_params = set()
    groups = []
    for module in model.modules():
        if hasattr(module, "make_optim_group"):
            group = module.make_optim_group()
            params = set(group["params"])
            # A parameter may only belong to one group.
            assert params.isdisjoint(seen_params)
            seen_params |= params
            groups.append(group)
    other_params = [param for param in model.parameters() if param not in seen_params]
    groups.insert(0, {"params": other_params})

    optimizers = {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW}
    name = args.optim.optim
    if name not in optimizers:
        raise ValueError(f"Invalid optimizer {name}")
    return optimizers[name](
        groups,
        lr=args.optim.lr,
        betas=(args.optim.momentum, args.optim.beta2),
        weight_decay=args.optim.weight_decay,
    )
+
+
def get_datasets(args):
    """Build the training and validation datasets described by `args.dset`.

    Datasets are combined in this order:
      * MusDB when `use_musdb` is set, otherwise empty lists;
      * the custom `wav` dataset: concatenated with the above when there are
        at most 4 sources, otherwise it replaces them;
      * the optional second `wav2` dataset. When `wav2_weight` is set, the
        training set built so far is repeated `reps` times so that `wav2`
        makes up roughly that fraction of training items. Its validation part
        is only added when `wav2_valid` is set, subsampled to the same share.
    Finally, the validation set may be subsampled to `valid_samples` items.

    Returns:
        `(train_set, valid_set)`; both are asserted to be non-empty.
    """
    if args.dset.backend:
        torchaudio.set_audio_backend(args.dset.backend)
    if args.dset.use_musdb:
        train_set, valid_set = get_musdb_wav_datasets(args.dset)
    else:
        train_set, valid_set = [], []
    if args.dset.wav:
        extra_train_set, extra_valid_set = get_wav_datasets(args.dset)
        # NOTE(review): presumably the <= 4 check is because MusDB only has 4 stems,
        # so a dataset with more sources cannot be mixed with it — confirm.
        if len(args.dset.sources) <= 4:
            train_set = ConcatDataset([train_set, extra_train_set])
            valid_set = ConcatDataset([valid_set, extra_valid_set])
        else:
            train_set = extra_train_set
            valid_set = extra_valid_set

    if args.dset.wav2:
        extra_train_set, extra_valid_set = get_wav_datasets(args.dset, "wav2")
        weight = args.dset.wav2_weight
        if weight is not None:
            b = len(train_set)
            e = len(extra_train_set)
            # With reps * b = e * (1 - w) / w, wav2 is a fraction w of the mix.
            reps = max(1, round(e / b * (1 / weight - 1)))
        else:
            reps = 1
        train_set = ConcatDataset([train_set] * reps + [extra_train_set])
        if args.dset.wav2_valid:
            if weight is not None:
                b = len(valid_set)
                # Keep n such that n / (b + n) = w on the validation side too.
                n_kept = int(round(weight * b / (1 - weight)))
                valid_set = ConcatDataset(
                    [valid_set, random_subset(extra_valid_set, n_kept)]
                )
            else:
                valid_set = ConcatDataset([valid_set, extra_valid_set])
    if args.dset.valid_samples is not None:
        valid_set = random_subset(valid_set, args.dset.valid_samples)
    assert len(train_set)
    assert len(valid_set)
    return train_set, valid_set
+
+
def get_solver(args, model_only=False):
    """Build the model, optimizer and (unless `model_only`) data loaders, wrapped in a `Solver`.

    Side effects:
      * initializes the distributed backend and seeds torch with `args.seed`;
      * moves the model to CUDA when available;
      * divides `args.batch_size` in place by the distributed world size;
      * when `args.misc.show` is set, logs a model summary and exits the
        process instead of returning.
    """
    distrib.init()

    torch.manual_seed(args.seed)
    model = get_model(args)
    if args.misc.show:
        logger.info(model)
        # Size assuming 4 bytes per parameter.
        mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20
        logger.info('Size: %.1f MB', mb)
        if hasattr(model, 'valid_length'):
            field = model.valid_length(1)
            logger.info('Field: %.1f ms', field / args.dset.samplerate * 1000)
        sys.exit(0)

    # torch also initialize cuda seed if available
    if torch.cuda.is_available():
        model.cuda()

    # optimizer
    optimizer = get_optimizer(model, args)

    # `args.batch_size` is the global batch; each worker gets an equal share.
    assert args.batch_size % distrib.world_size == 0
    args.batch_size //= distrib.world_size

    if model_only:
        return Solver(None, model, optimizer, args)

    train_set, valid_set = get_datasets(args)

    if args.augment.repitch.proba:
        vocals = []
        if 'vocals' in args.dset.sources:
            vocals.append(args.dset.sources.index('vocals'))
        else:
            logger.warning('No vocal source found')
        train_set = RepitchedWrapper(train_set, vocals=vocals, **args.augment.repitch)

    logger.info("train/valid set size: %d %d", len(train_set), len(valid_set))
    train_loader = distrib.loader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.misc.num_workers, drop_last=True)
    if args.dset.full_cv:
        valid_loader = distrib.loader(
            valid_set, batch_size=1, shuffle=False,
            num_workers=args.misc.num_workers)
    else:
        valid_loader = distrib.loader(
            valid_set, batch_size=args.batch_size, shuffle=False,
            num_workers=args.misc.num_workers, drop_last=True)
    loaders = {"train": train_loader, "valid": valid_loader}

    # Construct Solver
    return Solver(loaders, model, optimizer, args)
+
+
def get_solver_from_sig(sig, model_only=False):
    """Rebuild the `Solver` of a past Dora experiment identified by signature `sig`.

    Any currently initialized global Hydra instance is cleared while Dora
    resolves the experiment (`main.get_xp_from_sig`), then re-initialized
    with the saved Hydra object. Presumably this keeps the two Hydra
    initializations from conflicting — confirm. The solver is built from the
    experiment's config inside the experiment's context.
    """
    inst = GlobalHydra.instance()
    hyd = None
    if inst.is_initialized():
        # Remember the caller's Hydra so it can be restored below.
        hyd = inst.hydra
        inst.clear()
    xp = main.get_xp_from_sig(sig)
    if hyd is not None:
        # Drop whatever the lookup left behind and reinstate the caller's Hydra.
        inst.clear()
        inst.initialize(hyd)

    with xp.enter(stack=True):
        return get_solver(xp.cfg, model_only)
+
+
@hydra_main(config_path="../conf", config_name="config", version_base="1.1")
def main(args):
    """Hydra/Dora training entry point: normalize paths and environment, then train.

    Dataset paths are converted with `hydra.utils.to_absolute_path` because the
    run may not execute from the directory the user launched it from (logs and
    checkpoints live in `os.getcwd()`).
    """
    global __file__
    __file__ = hydra.utils.to_absolute_path(__file__)
    # `wav2` is read by `get_datasets` just like `wav`, so it needs the same treatment.
    for attr in ["musdb", "wav", "wav2", "metadata"]:
        val = getattr(args.dset, attr)
        if val is not None:
            setattr(args.dset, attr, hydra.utils.to_absolute_path(val))

    # Limit intra-op threading of the numerical libraries to one thread.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"

    if args.misc.verbose:
        logger.setLevel(logging.DEBUG)

    logger.info("For logs, checkpoints and samples check %s", os.getcwd())
    logger.debug(args)
    from dora import get_xp
    logger.debug(get_xp().cfg)

    solver = get_solver(args)
    solver.train()
+
+
# When `_DORA_TEST_PATH` is set (the name suggests a test hook), Dora stores its
# experiments under that directory instead of the default one.
if '_DORA_TEST_PATH' in os.environ:
    main.dora.dir = Path(os.environ['_DORA_TEST_PATH'])


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
diff --git a/AIMeiSheng/demucs/transformer.py b/AIMeiSheng/demucs/transformer.py
new file mode 100644
index 0000000..56a465b
--- /dev/null
+++ b/AIMeiSheng/demucs/transformer.py
@@ -0,0 +1,839 @@
+# Copyright (c) 2019-present, Meta, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# First author is Simon Rouard.
+
+import random
+import typing as tp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import math
+from einops import rearrange
+
+
def create_sin_embedding(
    length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
):
    """Sinusoidal positional embedding in TBC layout.

    Args:
        length: number of positions.
        dim: embedding size, must be even.
        shift: offset added to every position.
        device: device of the returned tensor.
        max_period: period of the lowest frequency.

    Returns:
        Tensor of shape (length, 1, dim): cosines in the first half of the
        channels, sines in the second half.
    """
    # We aim for TBC format
    assert dim % 2 == 0
    pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
    half_dim = dim // 2
    adim = torch.arange(half_dim, device=device).view(1, 1, -1)
    # With half_dim == 1 the exponent would be 0 / 0 (NaN); a single
    # frequency of 1 is the natural limit, so guard the denominator.
    phase = pos / (max_period ** (adim / max(half_dim - 1, 1)))
    return torch.cat(
        [
            torch.cos(phase),
            torch.sin(phase),
        ],
        dim=-1,
    )
+
+
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    """
    :param d_model: dimension of the model, must be a multiple of 4
    :param height: height of the positions
    :param width: width of the positions
    :param device: device the result is moved to
    :param max_period: period of the lowest frequency
    :return: tensor of shape (1, d_model, height, width). The first half of the
        channels encodes the width position and the second half the height
        position; within each half, even offsets hold sines and odd offsets cosines.
    """
    if d_model % 4 != 0:
        # Each half of the channels needs a whole number of sin/cos pairs.
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "dimension not divisible by 4 (got dim={:d})".format(d_model)
        )
    pe = torch.zeros(d_model, height, width)
    # Each dimension use half of d_model
    d_model = int(d_model / 2)
    div_term = torch.exp(
        torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
    )
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    pe[0:d_model:2, :, :] = (
        torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    )
    pe[1:d_model:2, :, :] = (
        torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    )
    pe[d_model::2, :, :] = (
        torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    )
    pe[d_model + 1:: 2, :, :] = (
        torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    )

    return pe[None, :].to(device)
+
+
def create_sin_embedding_cape(
    length: int,
    dim: int,
    batch_size: int,
    mean_normalize: bool,
    augment: bool,  # True during training
    max_global_shift: float = 0.0,  # delta max
    max_local_shift: float = 0.0,  # epsilon max
    max_scale: float = 1.0,
    device: str = "cpu",
    max_period: float = 10000.0,
):
    """Sinusoidal embedding with randomly augmented positions (CAPE).

    Positions 0..length-1 are optionally centered (`mean_normalize`). When
    `augment` is True, they are perturbed per batch item by a global shift
    drawn in [-max_global_shift, max_global_shift], a per-position local shift
    in [-max_local_shift, max_local_shift], and a log-uniform scale in
    [1/max_scale, max_scale]. The draws use NumPy's global RNG.

    Returns:
        float32 tensor of shape (length, batch_size, dim): cosines in the first
        half of the channels, sines in the second half.
    """
    # We aim for TBC format
    assert dim % 2 == 0
    pos = 1.0 * torch.arange(length).view(-1, 1, 1)  # (length, 1, 1)
    pos = pos.repeat(1, batch_size, 1)  # (length, batch_size, 1)
    if mean_normalize:
        pos -= torch.nanmean(pos, dim=0, keepdim=True)

    if augment:
        delta = np.random.uniform(
            -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
        )
        delta_local = np.random.uniform(
            -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
        )
        log_lambdas = np.random.uniform(
            -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
        )
        pos = (pos + delta + delta_local) * np.exp(log_lambdas)

    pos = pos.to(device)

    half_dim = dim // 2
    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    # Guard the exponent denominator: with half_dim == 1 it would be 0 / 0 (NaN).
    phase = pos / (max_period ** (adim / max(half_dim - 1, 1)))
    return torch.cat(
        [
            torch.cos(phase),
            torch.sin(phase),
        ],
        dim=-1,
    ).float()
+
+
def get_causal_mask(length):
    """Return a (length, length) bool tensor that is True where column > row.

    These are the strictly-future positions that a causal attention must
    not attend to.
    """
    positions = torch.arange(length)
    rows = positions.unsqueeze(1)
    cols = positions.unsqueeze(0)
    return cols > rows
+
+
def get_elementary_mask(
    T1,
    T2,
    mask_type,
    sparse_attn_window,
    global_window,
    mask_random_seed,
    sparsity,
    device,
):
    """
    When the input of the Decoder has length T1 and the output T2
    The mask matrix has shape (T2, T1)

    mask_type is one of:
      * "global": the first `global_window` columns and the first
        `int(global_window * T2 / T1)` rows are fully allowed;
      * "diag": a band of +/- `sparse_attn_window` around the diagonal,
        rescaled to the T2 x T1 aspect ratio;
      * "jmask": offsets at triangular numbers (0, 1, 3, 6, ...) on both
        sides of the rescaled diagonal;
      * "random": each entry is True when a uniform draw from a generator
        seeded with `mask_random_seed` exceeds `sparsity`.
    The result is a bool tensor moved to `device`.
    """
    assert mask_type in ["diag", "jmask", "random", "global"]

    if mask_type == "global":
        mask = torch.zeros(T2, T1, dtype=torch.bool)
        mask[:, :global_window] = True
        line_window = int(global_window * T2 / T1)
        mask[:line_window, :] = True
    elif mask_type == "diag":
        rows = torch.arange(T2)[:, None]
        offsets = torch.arange(-sparse_attn_window, sparse_attn_window + 1)
        cols = (T1 / T2 * rows + offsets).long().clamp(0, T1 - 1)
        mask = torch.zeros(T2, T1, dtype=torch.bool)
        mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
    elif mask_type == "jmask":
        steps = torch.arange(0, int((2 * T1) ** 0.5 + 1))
        steps = (steps * (steps + 1) / 2).int()
        steps = torch.cat([-steps.flip(0)[:-1], steps])
        # Work on a one-cell padded grid so clamped indices land in the border,
        # which is cropped away afterwards.
        rows = torch.arange(T2 + 2)[:, None]
        cols = (T1 / T2 * rows + steps).long().clamp(0, T1 + 1)
        padded = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
        padded.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
        mask = padded[1:-1, 1:-1]
    else:  # "random"
        gene = torch.Generator(device=device)
        gene.manual_seed(mask_random_seed)
        draws = torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1)
        mask = draws > sparsity

    return mask.to(device)
+
+
def get_mask(
    T1,
    T2,
    mask_type,
    sparse_attn_window,
    global_window,
    mask_random_seed,
    sparsity,
    device,
):
    """
    Return a SparseCSRTensor mask that is a combination of elementary masks
    mask_type can be a combination of multiple masks: for instance "diag_jmask_random"

    The elementary masks (see `get_elementary_mask`) are OR-ed together, and
    the (T2, T1) result gets a leading batch axis of size 1 before conversion.
    """
    from xformers.sparse import SparseCSRTensor

    elementary = []
    for kind in mask_type.split("_"):
        elementary.append(
            get_elementary_mask(
                T1,
                T2,
                kind,
                sparse_attn_window,
                global_window,
                mask_random_seed,
                sparsity,
                device,
            )
        )
    union = torch.stack(elementary).any(dim=0)
    return SparseCSRTensor.from_dense(union[None])
+
+
class ScaledEmbedding(nn.Module):
    """`nn.Embedding` whose stored weights are divided by `boost` and whose
    outputs are multiplied by `boost`.

    The initial effective weights are the default embedding init times `scale`.
    Storing them `boost` times smaller makes optimizer updates effectively
    `boost` times larger.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        scale: float = 1.0,
        boost: float = 3.0,
    ):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.mul_(scale / boost)
        self.boost = boost

    @property
    def weight(self):
        """Effective (boosted) embedding matrix."""
        boost = self.boost
        return self.embedding.weight * boost

    def forward(self, x):
        looked_up = self.embedding(x)
        return looked_up * self.boost
+
+
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    Multiplies each channel by a learnt scale initialized to `init`, so residual
    branches start close to 0 when `init` is small.
    """

    def __init__(self, channels: int, init: float = 0, channel_last=False):
        """
        channel_last = False corresponds to (B, C, T) tensors
        channel_last = True corresponds to (T, B, C) tensors
        """
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.full((channels,), float(init)))

    def forward(self, x):
        # Channels are last: broadcast directly; otherwise add a trailing axis for T.
        factor = self.scale if self.channel_last else self.scale[:, None]
        return factor * x
+
+
class MyGroupNorm(nn.GroupNorm):
    """`nn.GroupNorm` for channel-last (B, T, C) inputs.

    The constructor is inherited unchanged from `nn.GroupNorm`.
    """

    def forward(self, x):
        """
        x: (B, T, C)
        if num_groups=1: Normalisation on all T and C together for each B
        """
        channels_first = x.transpose(1, 2)
        normed = super().forward(channels_first)
        return normed.transpose(1, 2)
+
+
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """`nn.TransformerEncoderLayer` with optional group norm, layer scale, an
    extra output norm and sparse self-attention.

    * `group_norm` (number of groups, 0 disables) replaces `norm1`/`norm2`
      with `MyGroupNorm`.
    * `norm_out` (number of groups, falsy disables) adds a final `MyGroupNorm`;
      it is only created and used when `norm_first` is True.
    * `layer_scale` wraps both residual branches in `LayerScale(init_values)`.
    * `sparse` swaps self-attention for this module's `MultiheadAttention`.
      Unless `auto_sparsity` is set, a fixed mask is built by `get_mask` from
      `mask_type`, `sparse_attn_window`, `global_window`, `mask_random_seed`
      and `sparsity`.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        group_norm=0,
        norm_first=False,
        norm_out=False,
        layer_norm_eps=1e-5,
        layer_scale=False,
        init_values=1e-4,
        device=None,
        dtype=None,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        auto_sparsity=False,
        sparsity=0.95,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )
        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        # Logical `and`: `norm_out` may be a group count, and `True & 2` would be 0.
        if self.norm_first and norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

        if sparse:
            self.self_attn = MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=batch_first,
                auto_sparsity=sparsity if auto_sparsity else 0,
            )
            # Placeholder; replaced by a real mask on the first forward.
            self.src_mask = torch.zeros(1, 1)
            self.mask_random_seed = mask_random_seed

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        `src` is (T, B, C) when `batch_first=False` and (B, T, C) when
        `batch_first=True`. With a fixed sparse mask, `src_mask` must be None;
        the mask is built for the current sequence length and cached.
        """
        device = src.device
        x = src
        # Read the sequence length from the right axis; using axis 0 with
        # batch_first=True would build the sparse mask for the batch size.
        T = x.shape[1] if self.self_attn.batch_first else x.shape[0]
        if self.sparse and not self.auto_sparsity:
            assert src_mask is None
            src_mask = self.src_mask
            if src_mask.shape[-1] != T:
                src_mask = get_mask(
                    T,
                    T,
                    self.mask_type,
                    self.sparse_attn_window,
                    self.global_window,
                    self.mask_random_seed,
                    self.sparsity,
                    device,
                )
                self.src_mask = src_mask

        if self.norm_first:
            x = x + self.gamma_1(
                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            )
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))

            if self.norm_out:
                x = self.norm_out(x)
        else:
            x = self.norm1(
                x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
            )
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x
+
+
class CrossTransformerEncoderLayer(nn.Module):
    """Transformer layer where queries `q` attend to a second sequence `k`,
    followed by a feed-forward block, both with residual connections.

    Options mirror `MyTransformerEncoderLayer`: `group_norm` swaps the three
    LayerNorms for `MyGroupNorm`; `norm_out` adds a final `MyGroupNorm` only
    when `norm_first` is True; `layer_scale` wraps the residual branches in
    `LayerScale`; `sparse` swaps the attention for this module's
    `MultiheadAttention` with a fixed `get_mask` mask unless `auto_sparsity`.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation=F.relu,
        layer_norm_eps: float = 1e-5,
        layer_scale: bool = False,
        init_values: float = 1e-4,
        norm_first: bool = False,
        group_norm: bool = False,
        norm_out: bool = False,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        sparsity=0.95,
        auto_sparsity=None,
        device=None,
        dtype=None,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity

        self.cross_attn: nn.Module
        # Same device/dtype as the other sub-layers.
        self.cross_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1: nn.Module
        self.norm2: nn.Module
        self.norm3: nn.Module
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        # Logical `and`: `norm_out` may be a group count, and `True & 2` would be 0.
        if self.norm_first and norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = self._get_activation_fn(activation)
        else:
            self.activation = activation

        if sparse:
            self.cross_attn = MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=batch_first,
                auto_sparsity=sparsity if auto_sparsity else 0)
            if not auto_sparsity:
                # Placeholder; replaced by a real mask on the first forward.
                self.mask = torch.zeros(1, 1)
            self.mask_random_seed = mask_random_seed

    def forward(self, q, k, mask=None):
        """
        Args:
            q: tensor of shape (T, B, C), or (B, T, C) when batch_first
            k: tensor of shape (S, B, C), or (B, S, C) when batch_first
            mask: tensor of shape (T, S); must be None with a fixed sparse mask,
                which is then built for the current (T, S) and cached.
        """
        device = q.device
        # Sequence axis depends on the layout chosen at construction time.
        time_dim = 1 if self.cross_attn.batch_first else 0
        T = q.shape[time_dim]
        S = k.shape[time_dim]
        if self.sparse and not self.auto_sparsity:
            assert mask is None
            mask = self.mask
            if mask.shape[-1] != S or mask.shape[-2] != T:
                mask = get_mask(
                    S,
                    T,
                    self.mask_type,
                    self.sparse_attn_window,
                    self.global_window,
                    self.mask_random_seed,
                    self.sparsity,
                    device,
                )
                self.mask = mask

        if self.norm_first:
            x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
            x = x + self.gamma_2(self._ff_block(self.norm3(x)))
            if self.norm_out:
                x = self.norm_out(x)
        else:
            x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

    # cross-attention block
    def _ca_block(self, q, k, attn_mask=None):
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def _get_activation_fn(self, activation):
        """Map the legacy activation names "relu" / "gelu" to functions."""
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu

        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+
+
+# ----------------- MULTI-BLOCKS MODELS: -----------------------
+
+
class CrossTransformerEncoder(nn.Module):
    """Two-branch transformer: a spectral branch and a temporal branch.

    `forward(x, xt)` takes `x` of shape (B, C, Fr, T1) and `xt` of shape
    (B, C, T2). It adds positional embeddings (2D sinusoidal for `x`; the
    `emb` kind "sin" / "cape" / "scaled" for `xt`), then runs `num_layers`
    layers per branch. Layers alternate between self-attention within a
    branch and cross-attention where each branch attends to the other;
    `cross_first` selects which kind comes first. Both outputs keep their
    input shapes.
    """

    def __init__(
        self,
        dim: int,
        emb: str = "sin",
        hidden_scale: float = 4.0,
        num_heads: int = 8,
        num_layers: int = 6,
        cross_first: bool = False,
        dropout: float = 0.0,
        max_positions: int = 1000,
        norm_in: bool = True,
        norm_in_group: bool = False,
        group_norm: int = False,
        norm_first: bool = False,
        norm_out: bool = False,
        max_period: float = 10000.0,
        weight_decay: float = 0.0,
        lr: tp.Optional[float] = None,
        layer_scale: bool = False,
        gelu: bool = True,
        sin_random_shift: int = 0,
        weight_pos_embed: float = 1.0,
        cape_mean_normalize: bool = True,
        cape_augment: bool = True,
        # Immutable default: a shared list default could be mutated across instances.
        cape_glob_loc_scale: tp.Sequence[float] = (5000.0, 1.0, 1.4),
        sparse_self_attn: bool = False,
        sparse_cross_attn: bool = False,
        mask_type: str = "diag",
        mask_random_seed: int = 42,
        sparse_attn_window: int = 500,
        global_window: int = 50,
        auto_sparsity: bool = False,
        sparsity: float = 0.95,
    ):
        super().__init__()
        assert dim % num_heads == 0

        hidden_dim = int(dim * hidden_scale)

        self.num_layers = num_layers
        # classic parity = 1 means that if idx%2 == 1 there is a
        # classical encoder else there is a cross encoder
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift
        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            self.cape_glob_loc_scale = cape_glob_loc_scale
        if emb == "scaled":
            self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr

        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        # spectrogram layers
        self.layers = nn.ModuleList()
        # temporal layers
        self.layers_t = nn.ModuleList()

        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({
            "sparse": sparse_self_attn,
        })
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({
            "sparse": sparse_cross_attn,
        })

        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:

                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(
                    MyTransformerEncoderLayer(**kwargs_classic_encoder)
                )

            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))

                self.layers_t.append(
                    CrossTransformerEncoderLayer(**kwargs_cross_encoder)
                )

    def forward(self, x, xt):
        B, C, Fr, T1 = x.shape
        pos_emb_2d = create_2d_sin_embedding(
            C, Fr, T1, x.device, self.max_period
        )  # (1, C, Fr, T1)
        pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
        x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d

        B, C, T2 = xt.shape
        xt = rearrange(xt, "b c t2 -> b t2 c")  # now (B, T2, C)
        pos_emb = self._get_pos_embedding(T2, B, C, x.device)
        pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb

        for idx in range(self.num_layers):
            if idx % 2 == self.classic_parity:
                x = self.layers[idx](x)
                xt = self.layers_t[idx](xt)
            else:
                # Both cross layers must see the other branch *before* this step.
                old_x = x
                x = self.layers[idx](x, xt)
                xt = self.layers_t[idx](xt, old_x)

        x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
        xt = rearrange(xt, "b t2 c -> b c t2")
        return x, xt

    def _get_pos_embedding(self, T, B, C, device):
        """Positional embedding of the temporal branch, in (T, B or 1, C) layout.

        Raises:
            ValueError: if `self.emb` is not "sin", "cape" or "scaled".
        """
        if self.emb == "sin":
            shift = random.randrange(self.sin_random_shift + 1)
            pos_emb = create_sin_embedding(
                T, C, shift=shift, device=device, max_period=self.max_period
            )
        elif self.emb == "cape":
            if self.training:
                pos_emb = create_sin_embedding_cape(
                    T,
                    C,
                    B,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=self.cape_augment,
                    max_global_shift=self.cape_glob_loc_scale[0],
                    max_local_shift=self.cape_glob_loc_scale[1],
                    max_scale=self.cape_glob_loc_scale[2],
                )
            else:
                pos_emb = create_sin_embedding_cape(
                    T,
                    C,
                    B,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=False,
                )

        elif self.emb == "scaled":
            pos = torch.arange(T, device=device)
            pos_emb = self.position_embeddings(pos)[:, None]
        else:
            raise ValueError(f"Unknown positional embedding {self.emb!r}")

        return pos_emb

    def make_optim_group(self):
        """Parameter group with this module's weight decay (and lr, if set)."""
        group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
        if self.lr is not None:
            group["lr"] = self.lr
        return group
+
+
+# Attention Modules
+
+
class MultiheadAttention(nn.Module):
    """Multi-head attention with the `nn.MultiheadAttention` call signature,
    computed with xformers kernels (masked or dynamically sparse attention).

    Only `embed_dim`, `num_heads`, `dropout`, `bias`, `batch_first` and
    `auto_sparsity` are used. `add_bias_kv`, `add_zero_attn`, `kdim` and `vdim`
    are accepted for signature compatibility but ignored, so keys and values
    must have `embed_dim` features. A truthy `auto_sparsity` is passed as the
    sparsity level of `dynamic_sparse_attention`.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        auto_sparsity=None,
    ):
        super().__init__()
        assert auto_sparsity is not None, "sanity check"
        self.num_heads = num_heads
        self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = torch.nn.Dropout(dropout)
        self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
        self.proj_drop = torch.nn.Dropout(dropout)
        self.batch_first = batch_first
        self.auto_sparsity = auto_sparsity

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        average_attn_weights=True,
        is_causal=False,
    ):
        """Attend `query` to `key`/`value`.

        Inputs are (N, B, C), or (B, N, C) when `batch_first`. Returns
        `(output, None)`; attention weights are never returned, whatever
        `need_weights` says. `key_padding_mask` and `average_attn_weights` are
        ignored. `is_causal` is accepted because `nn.TransformerEncoderLayer`
        passes it to its `self_attn`; only `False` is supported.
        """
        if is_causal:
            raise NotImplementedError("is_causal=True is not supported by this attention module")

        if not self.batch_first:  # N, B, C
            query = query.permute(1, 0, 2)  # B, N_q, C
            key = key.permute(1, 0, 2)  # B, N_k, C
            value = value.permute(1, 0, 2)  # B, N_k, C
        B, N_q, C = query.shape
        B, N_k, C = key.shape
        head_dim = C // self.num_heads

        def split_heads(projected, length):
            # (B, N, C) -> (B * heads, N, head_dim)
            return (
                projected.reshape(B, length, self.num_heads, head_dim)
                .permute(0, 2, 1, 3)
                .flatten(0, 1)
            )

        q = split_heads(self.q(query), N_q)
        k = split_heads(self.k(key), N_k)
        v = split_heads(self.v(value), N_k)

        if self.auto_sparsity:
            assert attn_mask is None
            x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
        else:
            x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
        x = x.reshape(B, self.num_heads, N_q, head_dim)

        # Merge heads back: (B, heads, N_q, head_dim) -> (B, N_q, C)
        x = x.transpose(1, 2).reshape(B, N_q, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if not self.batch_first:
            x = x.permute(1, 0, 2)
        return x, None
+
+
def scaled_query_key_softmax(q, k, att_mask):
    """Return softmax(q @ k^T / sqrt(d)) over the last axis.

    `d` is the feature size of `k`. The logits are computed by xformers'
    `masked_matmul`, which applies `att_mask`.
    """
    from xformers.ops import masked_matmul
    temperature = (k.size(-1)) ** 0.5
    logits = masked_matmul(q / temperature, k.transpose(-2, -1), att_mask)
    return torch.nn.functional.softmax(logits, -1)
+
+
def scaled_dot_product_attention(q, k, v, att_mask, dropout):
    """Return dropout(softmax(q @ k^T / sqrt(d))) @ v.

    `att_mask` is applied when computing the logits. `dropout` is a callable
    applied to the attention weights (the caller passes an `nn.Dropout`).
    """
    weights = dropout(scaled_query_key_softmax(q, k, att_mask=att_mask))
    return weights @ v
+
+
+def _compute_buckets(x, R):
+ qq = torch.einsum('btf,bfhi->bhti', x, R)
+ qq = torch.cat([qq, -qq], dim=-1)
+ buckets = qq.argmax(dim=-1)
+
+ return buckets.permute(0, 2, 1).byte().contiguous()
+
+
def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
    """Attention restricted to query/key pairs that fall into matching hash buckets.

    Buckets come from `_compute_buckets` with a fresh random projection
    (32 hashes of 2 directions each, drawn from torch's global RNG). xformers'
    `find_locations` turns them into a sparse layout, and
    `sparse_memory_efficient_attention` computes the attention on it. The
    bucketing runs without gradient; the attention itself is differentiable.
    """
    # assert False, "The code for the custom sparse kernel is not ready for release yet."
    from xformers.ops import find_locations, sparse_memory_efficient_attention
    n_hashes = 32
    proj_size = 4
    query, key, value = (tensor.contiguous() for tensor in (query, key, value))
    with torch.no_grad():
        projection = torch.randn(
            1, query.shape[-1], n_hashes, proj_size // 2, device=query.device
        )
        query_buckets = _compute_buckets(query, projection)
        key_buckets = _compute_buckets(key, projection)
        row_offsets, column_indices = find_locations(
            query_buckets, key_buckets, sparsity, infer_sparsity)
    return sparse_memory_efficient_attention(
        query, key, value, row_offsets, column_indices, attn_bias)
diff --git a/AIMeiSheng/demucs/utils.py b/AIMeiSheng/demucs/utils.py
new file mode 100644
index 0000000..a3f5993
--- /dev/null
+++ b/AIMeiSheng/demucs/utils.py
@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from concurrent.futures import CancelledError
+from contextlib import contextmanager
+import math
+import os
+import tempfile
+import typing as tp
+
+import torch
+from torch.nn import functional as F
+from torch.utils.data import Subset
+
+
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input (with zeros, on the right) so that
    `F = ceil(T / stride)` and the last frame still holds K samples.

    The frames are a strided view over a padded copy of `a`. Non-contiguous
    inputs are supported: the padded buffer is made contiguous first.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    # When no padding is needed, `F.pad` can return a clone that keeps the
    # input's (possibly non-contiguous) strides; `as_strided` below requires a
    # dense row-major buffer, so force contiguity here.
    a = F.pad(a, (0, tgt_length - length)).contiguous()
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    # Consecutive frames start `stride` samples apart; samples inside a frame are adjacent.
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
+
+
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Trim `tensor` symmetrically along its last dimension down to a target length.

    The target is `reference.size(-1)` when `reference` is a tensor, otherwise
    `reference` itself. When the excess is odd, the right side loses one sample
    more than the left. Raises ValueError if `tensor` is shorter than the target;
    returns `tensor` unchanged when the lengths already match.
    """
    ref_size: int = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
    if delta == 0:
        return tensor
    left = delta // 2
    right = delta - left
    return tensor[..., left:-right]
+
+
def pull_metric(history: tp.List[dict], name: str):
    """Extract one (possibly nested) metric from every entry of `history`.

    `name` is a dot-separated path into each nested dict, e.g. "train.loss"
    reads `entry["train"]["loss"]`. A missing key raises KeyError.
    """
    path = name.split(".")

    def _lookup(entry):
        for part in path:
            entry = entry[part]
        return entry

    return [_lookup(metrics) for metrics in history]
+
+
def EMA(beta: float = 1):
    """
    Exponential Moving Average callback.

    Returns a function that can be called repeatedly with a dict of metrics
    (and an optional `weight`). Each call decays the running sums by `beta`,
    folds in the new values, and returns the current weighted average of every
    metric seen so far (including ones absent from the latest call).

    With `beta=1` nothing decays, so this is a plain (weighted) average.
    """
    weighted_sums: tp.Dict[str, float] = defaultdict(float)
    weight_totals: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        # The dicts are mutated in place, so no `nonlocal` rebinding is needed.
        for metric_name, raw in metrics.items():
            weighted_sums[metric_name] = weighted_sums[metric_name] * beta + weight * float(raw)
            weight_totals[metric_name] = weight_totals[metric_name] * beta + weight
        return {metric_name: acc / weight_totals[metric_name]
                for metric_name, acc in weighted_sums.items()}

    return _update
+
+
def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Format a byte count as a human readable string with binary (IEC) prefixes,
    e.g. 1536 -> '1.5KiB'. Anything beyond the zebi range is expressed in 'Yi'.
    Adapted from https://stackoverflow.com/a/1094933
    """
    value = num
    for prefix in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'):
        if -1024.0 < value < 1024.0:
            return "%3.1f%s%s" % (value, prefix, suffix)
        value /= 1024.0
    return "%.1f%s%s" % (value, 'Yi', suffix)
+
+
@contextmanager
def temp_filenames(count: int, delete=True):
    """Yield a list of `count` freshly created temporary file paths.

    The files exist (empty) when the block starts. When `delete` is True they
    are removed on exit. A file that the caller already removed is silently
    skipped rather than raising, because removal is the goal anyway.
    """
    names = []
    try:
        for _ in range(count):
            # Close the handle immediately: only the path is handed out, and an
            # unclosed NamedTemporaryFile would leak a file descriptor.
            with tempfile.NamedTemporaryFile(delete=False) as handle:
                names.append(handle.name)
        yield names
    finally:
        if delete:
            for name in names:
                try:
                    os.unlink(name)
                except FileNotFoundError:
                    # Already gone (e.g. the caller cleaned it up): nothing to do.
                    pass
+
+
def random_subset(dataset, max_samples: int, seed: int = 42):
    """Return at most `max_samples` items of `dataset`, picked reproducibly.

    If the dataset is not larger than `max_samples`, it is returned as-is.
    Otherwise a `Subset` over a seeded random permutation is returned, so the
    same `seed` always selects the same indices.
    """
    total = len(dataset)
    if total <= max_samples:
        return dataset
    rng = torch.Generator()
    rng.manual_seed(seed)
    chosen = torch.randperm(total, generator=rng)[:max_samples]
    return Subset(dataset, chosen.tolist())
+
+
class DummyPoolExecutor:
    """Synchronous stand-in for a `concurrent.futures` executor.

    `submit` does not run anything. The call is deferred until `.result()` is
    requested on the returned handle. After `shutdown()`, pending handles raise
    `CancelledError` instead of running.
    """

    class DummyResult:
        """Deferred call: runs `func(*args, **kwargs)` lazily in `result()`."""

        def __init__(self, func, _dict, *args, **kwargs):
            self.func = func
            # Shared with the parent executor so shutdown() is visible here.
            self._dict = _dict
            self.args = args
            self.kwargs = kwargs

        def result(self):
            if not self._dict["run"]:
                raise CancelledError()
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        # `workers` is accepted for API compatibility only; nothing runs in parallel.
        self._dict = {"run": True}

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, self._dict, *args, **kwargs)

    def shutdown(self, *_, **__):
        self._dict["run"] = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Returning None lets any exception raised in the `with` body propagate.
        return None
diff --git a/AIMeiSheng/demucs/wav.py b/AIMeiSheng/demucs/wav.py
new file mode 100644
index 0000000..6acb9b5
--- /dev/null
+++ b/AIMeiSheng/demucs/wav.py
@@ -0,0 +1,254 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Loading wav based datasets, including MusdbHQ."""
+
+from collections import OrderedDict
+import hashlib
+import math
+import json
+import os
+from pathlib import Path
+import tqdm
+
+import musdb
+import julius
+import torch as th
+from torch import distributed
+import torchaudio as ta
+from torch.nn import functional as F
+
+from .audio import convert_audio_channels
+from . import distrib
+
# Stem name of the full mix inside each track folder. If the file is missing,
# `_track_metadata` synthesizes it by summing all source stems.
MIXTURE = "mixture"
# Default audio extension, appended directly to stem names (`{source}{ext}`),
# so it includes the leading dot.
EXT = ".wav"
+
+
def _track_metadata(track, sources, normalize=True, ext=EXT):
    """Collect length, sample rate and normalization statistics for one track folder.

    Args:
        track (Path): folder holding one `{source}{ext}` file per source
            (joined with `/`; `build_metadata` passes a `pathlib.Path`).
        sources (list[str]): stem names expected in the folder.
        normalize (bool): if True, load the mixture and compute mean/std of its
            channel-averaged waveform; otherwise mean=0 and std=1 are returned.
        ext (str): audio file extension including the leading dot.

    Returns:
        dict with keys "length" (number of frames), "mean", "std" and "samplerate".

    Raises:
        ValueError: if the stems disagree on frame count or sample rate.
        RuntimeError: re-raised from torchaudio after printing the offending file.

    Side effect: if the mixture file is missing, it is created on disk as the sum
    of all source stems.
    """
    track_length = None
    track_samplerate = None
    mean = 0
    std = 1
    for source in sources + [MIXTURE]:
        file = track / f"{source}{ext}"
        if source == MIXTURE and not file.exists():
            # Missing mixture: synthesize it by summing every source stem.
            audio = 0
            for sub_source in sources:
                sub_file = track / f"{sub_source}{ext}"
                sub_audio, sr = ta.load(sub_file)
                audio += sub_audio
            would_clip = audio.abs().max() >= 1
            if would_clip:
                # Values >= 1 are kept by the float (PCM_F) encoding below; per the
                # assert message this needs torchaudio's soundfile backend.
                assert ta.get_audio_backend() == 'soundfile', 'use dset.backend=soundfile'
            ta.save(file, audio, sr, encoding='PCM_F')

        try:
            info = ta.info(str(file))
        except RuntimeError:
            # Print which file failed before propagating the error.
            print(file)
            raise
        length = info.num_frames
        if track_length is None:
            # The first stem defines the expected length and sample rate.
            track_length = length
            track_samplerate = info.sample_rate
        elif track_length != length:
            raise ValueError(
                f"Invalid length for file {file}: "
                f"expecting {track_length} but got {length}.")
        elif info.sample_rate != track_samplerate:
            raise ValueError(
                f"Invalid sample rate for file {file}: "
                f"expecting {track_samplerate} but got {info.sample_rate}.")
        if source == MIXTURE and normalize:
            try:
                wav, _ = ta.load(str(file))
            except RuntimeError:
                print(file)
                raise
            # Average over the first (channel) axis before computing the statistics.
            wav = wav.mean(0)
            mean = wav.mean().item()
            std = wav.std().item()

    return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}
+
+
def build_metadata(path, sources, normalize=True, ext=EXT):
    """
    Build the metadata consumed by `Wavset`.

    Every leaf directory under `path` is treated as one track. A leaf directory
    has no sub-directories, is not hidden (name not starting with '.') and is not
    `path` itself. The tracks are scanned in parallel on 8 threads.

    Args:
        path (str or Path): path to dataset.
        sources (list[str]): list of sources to look for.
        normalize (bool): if True, loads full track and store normalization
            values based on the mixture file.
        ext (str): extension of audio files (default is .wav).

    Returns:
        dict mapping each track's path relative to `path` (as str) to the dict
        returned by `_track_metadata`.
    """
    from concurrent.futures import ThreadPoolExecutor
    meta = {}
    base = Path(path)
    jobs = []
    with ThreadPoolExecutor(8) as pool:
        for dirpath, subdirs, _files in os.walk(base, followlinks=True):
            track_dir = Path(dirpath)
            is_track = not (track_dir.name.startswith('.') or subdirs or track_dir == base)
            if not is_track:
                continue
            track_name = str(track_dir.relative_to(base))
            jobs.append((track_name, pool.submit(_track_metadata, track_dir, sources, normalize, ext)))
        # Gather in submission order; any worker exception is re-raised here.
        for track_name, job in tqdm.tqdm(jobs, ncols=120):
            meta[track_name] = job.result()
    return meta
+
+
class Wavset:
    def __init__(
            self,
            root, metadata, sources,
            segment=None, shift=None, normalize=True,
            samplerate=44100, channels=2, ext=EXT):
        """
        Waveset (or mp3 set for that matter). Can be used to train
        with arbitrary sources. Each track should be one folder inside of `root`.
        The folder should contain files named `{source}{ext}` (`ext` includes
        the leading dot, e.g. ".wav").

        Args:
            root (Path or str): root folder for the dataset.
            metadata (dict): output from `build_metadata`.
            sources (list[str]): list of source names.
            segment (None or float): segment length in seconds. If `None`, returns entire tracks.
            shift (None or float): stride in seconds between samples (defaults to `segment`).
            normalize (bool): normalizes input audio, **based on the metadata content**,
                i.e. the entire track is normalized, not individual extracts.
            samplerate (int): target sample rate. if the file sample rate
                is different, it will be resampled on the fly.
            channels (int): target nb of channels. if different, will be
                changed on the fly.
            ext (str): extension for audio files (default is .wav).

        samplerate and channels are converted on the fly.
        """
        self.root = Path(root)
        self.metadata = OrderedDict(metadata)
        self.segment = segment
        self.shift = shift or segment
        self.normalize = normalize
        self.sources = sources
        self.channels = channels
        self.samplerate = samplerate
        self.ext = ext
        # Number of examples contributed by each track, in metadata order.
        self.num_examples = []
        for name, meta in self.metadata.items():
            track_duration = meta['length'] / meta['samplerate']
            if segment is None or track_duration < segment:
                examples = 1
            else:
                examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1)
            self.num_examples.append(examples)

    def __len__(self):
        return sum(self.num_examples)

    def get_file(self, name, source):
        return self.root / name / f"{source}{self.ext}"

    def __getitem__(self, index):
        """Return example `index`: one waveform per source, stacked along a new first axis.

        Negative indices count from the end. Out-of-range indices raise IndexError,
        which is also what lets plain iteration (`for x in dataset`) terminate.
        """
        size = len(self)
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError(f"Wavset index out of range (size {size})")
        for name, examples in zip(self.metadata, self.num_examples):
            if index >= examples:
                index -= examples
                continue
            meta = self.metadata[name]
            num_frames = -1
            offset = 0
            if self.segment is not None:
                offset = int(meta['samplerate'] * self.shift * index)
                num_frames = int(math.ceil(meta['samplerate'] * self.segment))
            wavs = []
            for source in self.sources:
                file = self.get_file(name, source)
                wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
                wav = convert_audio_channels(wav, self.channels)
                wavs.append(wav)

            example = th.stack(wavs)
            example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
            if self.normalize:
                example = (example - meta['mean']) / meta['std']
            if self.segment:
                # Crop / zero-pad to exactly one segment at the target sample rate.
                length = int(self.segment * self.samplerate)
                example = example[..., :length]
                example = F.pad(example, (0, length - example.shape[-1]))
            return example
        # Unreachable given the bounds check above; kept as a safety net.
        raise IndexError(f"Wavset index out of range (size {size})")
+
+
def get_wav_datasets(args, name='wav'):
    """Extract the wav datasets from the XP arguments.

    `getattr(args, name)` is the dataset root, with `train/` and `valid/`
    sub-folders. Track metadata is computed once by rank 0 and cached as JSON
    under `args.metadata`, in a file keyed by a hash of the dataset path. When
    running distributed, all ranks sync on a barrier before reading the cache.

    Returns:
        (train_set, valid_set) as `Wavset` instances. The validation set also
        yields the mixture as its first source. It uses full tracks when
        `args.full_cv` is set, else the same segment/shift as training.
    """
    path = getattr(args, name)
    sig = hashlib.sha1(str(path).encode()).hexdigest()[:8]
    metadata_file = Path(args.metadata) / ('wav_' + sig + ".json")
    train_path = Path(path) / "train"
    valid_path = Path(path) / "valid"
    if not metadata_file.is_file() and distrib.rank == 0:
        metadata_file.parent.mkdir(exist_ok=True, parents=True)
        train = build_metadata(train_path, args.sources)
        valid = build_metadata(valid_path, args.sources)
        # Write to a temporary file, close it, then rename atomically. The cache
        # is never observed half-written, even if this process dies mid-write.
        tmp_file = metadata_file.with_name(metadata_file.name + ".tmp")
        with open(tmp_file, "w") as fout:
            json.dump([train, valid], fout)
        os.replace(tmp_file, metadata_file)
    if distrib.world_size > 1:
        distributed.barrier()
    with open(metadata_file) as fin:
        train, valid = json.load(fin)
    if args.full_cv:
        kw_cv = {}
    else:
        kw_cv = {'segment': args.segment, 'shift': args.shift}
    train_set = Wavset(train_path, train, args.sources,
                       segment=args.segment, shift=args.shift,
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize)
    valid_set = Wavset(valid_path, valid, [MIXTURE] + list(args.sources),
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize, **kw_cv)
    return train_set, valid_set
+
+
def _get_musdb_valid():
    """Return the list of MUSDB validation track names.

    The list is read from the `validation_tracks` entry of the `configs/mus.yaml`
    file shipped inside the installed `musdb` package.
    """
    import yaml
    setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml'
    # Close the config file deterministically instead of leaking the handle.
    with open(setup_path, 'r') as fin:
        setup = yaml.safe_load(fin)
    return setup['validation_tracks']
+
+
def get_musdb_wav_datasets(args):
    """Extract the musdb dataset from the XP arguments.

    Metadata for `args.musdb`/train is computed once by rank 0 and cached as
    JSON under `args.metadata`. The tracks listed by `_get_musdb_valid()` form
    the validation set. They are removed from training unless `args.train_valid`
    is set.

    Returns:
        (train_set, valid_set) as `Wavset` instances. The validation set also
        yields the mixture as its first source.
    """
    sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8]
    metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json")
    root = Path(args.musdb) / "train"
    if not metadata_file.is_file() and distrib.rank == 0:
        metadata_file.parent.mkdir(exist_ok=True, parents=True)
        metadata = build_metadata(root, args.sources)
        # Write to a temporary file, close it, then rename atomically. The cache
        # is never observed half-written, even if this process dies mid-write.
        tmp_file = metadata_file.with_name(metadata_file.name + ".tmp")
        with open(tmp_file, "w") as fout:
            json.dump(metadata, fout)
        os.replace(tmp_file, metadata_file)
    if distrib.world_size > 1:
        distributed.barrier()
    with open(metadata_file) as fin:
        metadata = json.load(fin)

    valid_tracks = set(_get_musdb_valid())
    if args.train_valid:
        metadata_train = metadata
    else:
        metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks}
    metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks}
    if args.full_cv:
        kw_cv = {}
    else:
        kw_cv = {'segment': args.segment, 'shift': args.shift}
    train_set = Wavset(root, metadata_train, args.sources,
                       segment=args.segment, shift=args.shift,
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize)
    valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources),
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize, **kw_cv)
    return train_set, valid_set
diff --git a/AIMeiSheng/demucs/wdemucs.py b/AIMeiSheng/demucs/wdemucs.py
new file mode 100644
index 0000000..03d6dd3
--- /dev/null
+++ b/AIMeiSheng/demucs/wdemucs.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# For compat
+from .hdemucs import HDemucs
+
# Backward-compatible alias: `WDemucs` is the very same class object as `HDemucs`.
WDemucs = HDemucs
diff --git a/AIMeiSheng/docker_demo/Dockerfile b/AIMeiSheng/docker_demo/Dockerfile
index 94fb28a..bdb7ab3 100644
--- a/AIMeiSheng/docker_demo/Dockerfile
+++ b/AIMeiSheng/docker_demo/Dockerfile
@@ -1,29 +1,31 @@
# OS version / CUDA Version 11.8.0
# NAME="CentOS Linux" VERSION="7 (Core)"
# FROM starmaker.tencentcloudcr.com/starmaker/av/av:1.1
# Base image: python3.9, cuda118, centos7, plus ffmpeg
#FROM starmaker.tencentcloudcr.com/starmaker/av/av_base:1.0
FROM registry.ushow.media/av/av_base:1.0
#FROM av_base_test:1.0
# Point yum at vault.centos.org (instead of the mirrorlist) before installing system audio libs.
RUN source /etc/profile && sed -i 's|mirrorlist=|#mirrorlist=|g' /etc/yum.repos.d/CentOS-Base.repo && sed -i 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-Base.repo && yum clean all && yum install -y unzip && yum install -y libsndfile && yum install -y libsamplerate libsamplerate-devel
RUN source /etc/profile && pip3 install librosa==0.9.1 && pip3 install gradio && pip3 install torch==2.1.2 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# NOTE(review): this step bakes a COS access key id and secret key into the image
# layer history, so anyone able to pull the image can read them. Consider passing
# credentials at build/run time (e.g. BuildKit secrets or env vars) and rotating
# the keys committed here.
RUN source /etc/profile && pip3 install urllib3==1.26.15 && pip3 install coscmd && coscmd config -a AKIDoQmshFWXGitnQmrfCTYNwEExPaU6RVHm -s F9n9E2ZonWy93f04qMaYFfogHadPt62h -b log-sg-1256122840 -r ap-singapore
RUN source /etc/profile && pip3 install asteroid-filterbanks
RUN source /etc/profile && pip3 install praat-parselmouth==0.4.3
RUN source /etc/profile && pip3 install pyworld
RUN source /etc/profile && pip3 install faiss-cpu
RUN source /etc/profile && pip3 install torchcrepe
RUN source /etc/profile && pip3 install thop
RUN source /etc/profile && pip3 install ffmpeg-python
# NOTE(review): "pip3" is the command name, not the PyPI distribution name; if the
# intent is to pin pip itself, the requirement is presumably "pip==24.0" — confirm.
RUN source /etc/profile && pip3 install pip3==24.0
RUN source /etc/profile && pip3 install fairseq==0.12.2
RUN source /etc/profile && pip3 install redis==4.5.0
RUN source /etc/profile && pip3 install numpy==1.26.4
RUN source /etc/profile && pip3 install demucs

COPY ./ /data/code/
WORKDIR /data/code
CMD ["/bin/bash", "-c", "source /etc/profile; export PYTHONPATH=/data/code; cd /data/code/AIMeiSheng/docker_demo; python3 offline_server.py"]
-#CMD ["/bin/bash", "-c", "source /etc/profile; export PYTHONPATH=/data/code; cd /data/code/AIMeiSheng/docker_demo; python3 tmp.py"]
\ No newline at end of file
+#CMD ["/bin/bash", "-c", "source /etc/profile; export PYTHONPATH=/data/code; cd /data/code/AIMeiSheng/docker_demo; python3 tmp.py"]
diff --git a/AIMeiSheng/docker_demo/common.py b/AIMeiSheng/docker_demo/common.py
index 49b3628..2b2a7b3 100644
--- a/AIMeiSheng/docker_demo/common.py
+++ b/AIMeiSheng/docker_demo/common.py
@@ -1,130 +1,132 @@
import os
import sys
import time
# import logging
import urllib, urllib.request
# Test / production environment switch (True selects the production queues below).
gs_prod = True
# if len(sys.argv) > 1 and sys.argv[1] == "prod":
#     gs_prod = True
# print(gs_prod)
# Scratch space for per-request temporary files.
gs_tmp_dir = "/data/ai_meisheng_tmp"
# Local directory holding the downloaded model weights.
gs_model_dir = "/data/ai_meisheng_models"
gs_resource_cache_dir = "/tmp/ai_meisheng_resource_cache"
gs_embed_model_path = os.path.join(gs_model_dir, "RawNet3/models/weights/model.pt")
gs_svc_model_path = os.path.join(gs_model_dir,
                                 "weights/xusong_v2_org_version_alldata_embed_spkenx200x_double_e14_s90706.pth")
gs_hubert_model_path = os.path.join(gs_model_dir, "hubert.pt")
gs_rmvpe_model_path = os.path.join(gs_model_dir, "rmvpe.pt")
gs_embed_model_spk_path = os.path.join(gs_model_dir, "SpeakerEncoder/pretrained_model/best_model.pth.tar")
gs_embed_config_spk_path = os.path.join(gs_model_dir, "SpeakerEncoder/pretrained_model/config.json")
# Directory for the demucs weights. The trailing "/" is kept in case callers
# append file names to it by plain string concatenation.
gs_demucs_model_path = os.path.join(gs_model_dir, "demucs_model/")

# Error codes
gs_err_code_success = 0
gs_err_code_download_vocal = 100
gs_err_code_download_svc_url = 101
gs_err_code_svc_process = 102
gs_err_code_transcode = 103
gs_err_code_volume_adjust = 104
gs_err_code_upload = 105
gs_err_code_params = 106
gs_err_code_pending = 107
gs_err_code_target_silence = 108
gs_err_code_too_many_connections = 429
gs_err_code_gender_classify = 430
gs_err_code_vocal_ratio = 431  # vocal ratio (share of vocals in the recording) check failed
# NOTE(review): the Redis password is hard-coded in source; consider loading it
# from the environment or a secret store.
gs_redis_conf = {
    "host": "av-credis.starmaker.co",
    "port": 6379,
    "pwd": "lKoWEhz%jxTO",
}
# gs_server_redis_conf = {
#     "producer": "dev_ai_meisheng_producer",  # input queue
#     "ai_meisheng_key_prefix": "dev_ai_meisheng_key_",  # key prefix for stored results
# }
gs_server_redis_conf = {
    "producer": "test_ai_meisheng_producer",  # input queue
    "ai_meisheng_key_prefix": "test_ai_meisheng_key_",  # key prefix for stored results
}
if gs_prod:
    gs_server_redis_conf = {
        "producer": "ai_meisheng_producer",  # input queue
        "ai_meisheng_key_prefix": "ai_meisheng_key_",  # key prefix for stored results
    }
# Feishu endpoint plus a list of user phone numbers (presumably alert recipients).
gs_feishu_conf = {
    "url": "http://sg-prod-songbook-webmp-1:8000/api/feishu/people",
    "users": [
        "18810833785",  # Yang Jianli
        "17778007843",  # Wang Jianjun
        "18612496315",  # Guo Zihao
        "18600542290"  # Fang Bingxiao
    ]
}
def download2disk(url, dst_path):
    """Download `url` to `dst_path`; return True on success, False on any error.

    The data is first written to `dst_path + ".part"` and only renamed onto
    `dst_path` once the transfer completed. An interrupted or failed download
    therefore never leaves a truncated `dst_path` behind. Callers skip downloads
    when the destination exists, so a truncated file would otherwise be reused
    forever. Errors are printed, not raised.
    """
    tmp_path = f"{dst_path}.part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, dst_path)
        return os.path.exists(dst_path)
    except Exception as ex:
        print(f"download url={url} error", ex)
        try:
            os.remove(tmp_path)
        except OSError:
            # Nothing was written (or it is already gone): nothing to clean up.
            pass
        return False
def exec_cmd(cmd):
    """Echo `cmd`, run it through the shell, and report whether it exited with status 0."""
    print(cmd)
    return os.system(cmd) == 0
def exec_cmd_and_result(cmd):
    """Run `cmd` through the shell and return everything it wrote to stdout.

    The command's exit status is ignored.
    """
    with os.popen(cmd) as pipe:
        return pipe.read()
def upload_file2cos(key, file_path, region='ap-singapore', bucket_name='av-audit-sync-sg-1256122840'):
    """
    Upload a local file to COS with `coscmd`, then verify it landed.
    :param key: destination path (object key) inside the bucket
    :param file_path: local file path
    :param region: COS region
    :param bucket_name: bucket name
    :return: True when the upload command succeeded and the remote object's
        Content-Length is > 0; False otherwise. This includes the case where
        `coscmd info` printed no parsable size.
    """
    gs_coscmd = "coscmd"
    gs_coscmd_conf = "~/.cos.conf"
    cmd = "{} -c {} -r {} -b {} upload {} {}".format(gs_coscmd, gs_coscmd_conf, region, bucket_name, file_path, key)
    if not exec_cmd(cmd):
        return False
    cmd = "{} -c {} -r {} -b {} info {}".format(gs_coscmd, gs_coscmd_conf, region, bucket_name, key) \
          + "| grep Content-Length |awk \'{print $2}\'"
    res_str = exec_cmd_and_result(cmd)
    # logging.info("{},res={}".format(key, res_str))
    try:
        size = float(res_str)
    except ValueError:
        # `coscmd info` failed or printed no Content-Length line (empty output).
        return False
    return size > 0
def check_input(input_data):
    """Return True if `input_data` contains every field a request requires."""
    required = ("record_song_url", "target_url", "start", "end", "vocal_loudness",
                "female_recording_url", "male_recording_url")
    present = input_data.keys()
    return all(field in present for field in required)
diff --git a/AIMeiSheng/docker_demo/svc_online.py b/AIMeiSheng/docker_demo/svc_online.py
index 421b5f6..178b32d 100644
--- a/AIMeiSheng/docker_demo/svc_online.py
+++ b/AIMeiSheng/docker_demo/svc_online.py
@@ -1,207 +1,220 @@
# -*- coding: UTF-8 -*-
"""
SVC的核心处理逻辑
"""
import os
import time
import socket
import shutil
import hashlib
from AIMeiSheng.meisheng_svc_final import load_model, process_svc_online
from AIMeiSheng.cos_similar_ui_zoom import cos_similar
-from AIMeiSheng.meisheng_env_preparex import meisheng_env_prepare
+from AIMeiSheng.meisheng_env_preparex import meisheng_env_prepare, demucs_env_prepare
from AIMeiSheng.voice_classification.online.voice_class_online_fang import VoiceClass, download_volume_balanced
from AIMeiSheng.docker_demo.common import *
+# from AIMeiSheng.demucs_process_one import demucs_process_one
+from AIMeiSheng.separate_demucs import main_seperator
import logging
hostname = socket.gethostname()
log_file_name = f"{os.path.dirname(os.path.abspath(__file__))}/av_meisheng_{hostname}.log"
# Set up the logger.
svc_offline_logger = logging.getLogger("svc_offline")
# NOTE(review): FileHandler opens the log file at construction time (delay=False),
# yet the handler is never attached (see the commented-out lines below).
# svc_offline_logger therefore has no handler of its own and relies on
# propagation to ancestor loggers.
file_handler = logging.FileHandler(log_file_name)
file_handler.setLevel(logging.INFO)
# NOTE(review): %I is a 12-hour clock and no %p is included, so AM/PM timestamps are ambiguous.
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S')
file_handler.setFormatter(formatter)
# if gs_prod:
#     svc_offline_logger.addHandler(file_handler)
# Wipe the scratch directory left over from a previous run, then ensure the
# model and resource-cache directories exist.
if os.path.exists(gs_tmp_dir):
    shutil.rmtree(gs_tmp_dir)
os.makedirs(gs_model_dir, exist_ok=True)
os.makedirs(gs_resource_cache_dir, exist_ok=True)
# Preset resource URLs.
gs_gender_models_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/hub/voice_classification/models.zip"
gs_volume_bin_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/dataset/AIMeiSheng/ebur128_tool/v1/ebur128_tool"
class GSWorkerAttr:
    """Per-request state: input URLs, scratch file paths and timing parameters.

    Each request gets its own scratch directory `gs_tmp_dir/<md5(vocal_url)>`.
    Any existing directory with that name is deleted and recreated.
    """

    def __init__(self, input_data):
        # Pull the request fields out of the payload; start/end are in milliseconds.
        vocal_url = input_data["record_song_url"]
        target_url = input_data["target_url"]
        start = input_data["start"]
        end = input_data["end"]
        vocal_loudness = input_data["vocal_loudness"]
        female_recording_url = input_data["female_recording_url"]
        male_recording_url = input_data["male_recording_url"]

        self.distinct_id = hashlib.md5(vocal_url.encode()).hexdigest()
        self.tmp_dir = os.path.join(gs_tmp_dir, self.distinct_id)
        if os.path.exists(self.tmp_dir):
            shutil.rmtree(self.tmp_dir)
        os.makedirs(self.tmp_dir)

        def scratch(suffix):
            # Every scratch file is named "<distinct_id><suffix>" inside tmp_dir.
            return os.path.join(self.tmp_dir, self.distinct_id + suffix)

        def url_ext(url):
            # Extension taken as whatever follows the last "." of the URL.
            return url.split(".")[-1]

        self.vocal_url = vocal_url
        self.target_url = target_url
        self.vocal_path = scratch(f"_in.{url_ext(vocal_url)}")
        self.vocal_demucs_path = scratch("_in_demucs.wav")
        self.target_wav_path = scratch("_out.wav")
        self.target_wav_ad_path = scratch("_out_ad.wav")
        self.target_path = scratch("_out.m4a")
        self.female_svc_source_url = female_recording_url
        self.male_svc_source_url = male_recording_url
        self.female_svc_source_path = scratch(f"_female.{url_ext(female_recording_url)}")
        self.male_svc_source_path = scratch(f"_male.{url_ext(male_recording_url)}")
        self.st_tm = start
        self.ed_tm = end
        self.target_loudness = vocal_loudness

    def log_info_name(self):
        return f"d_id={self.distinct_id}, vocal_url={self.vocal_url}"

    def rm_cache(self):
        """Delete this request's scratch directory if it still exists."""
        if os.path.exists(self.tmp_dir):
            shutil.rmtree(self.tmp_dir)
def init_gender_model():
    """
    Make sure the voice/gender classification weights exist locally, then build the classifier.

    When `gs_model_dir/voice_classification` is missing, the zip archive is
    downloaded, extracted, and its `models` folder renamed into place. Failures
    are only logged via `svc_offline_logger.fatal`; execution continues.
    :return: a `VoiceClass` instance
    """
    model_dir = os.path.join(gs_model_dir, "voice_classification")
    if not os.path.exists(model_dir):
        zip_path = os.path.join(gs_model_dir, "models.zip")
        if not download2disk(gs_gender_models_url, zip_path):
            svc_offline_logger.fatal(f"download gender_model err={gs_gender_models_url}")
        os.system(f"cd {gs_model_dir}; unzip {zip_path}; mv models voice_classification; rm -f {zip_path}")
        if not os.path.exists(model_dir):
            svc_offline_logger.fatal(f"unzip {zip_path} err")
    weight_files = ("voice_005_rec_v5.pth",          # music/voice, pure
                    "voice_10_v5.pth",               # music/voice, not pure
                    "gender_8k_ratev5_v6_adam.pth",  # gender, pure
                    "gender_8k_v6_adam.pth")         # gender, not pure
    return VoiceClass(*(os.path.join(model_dir, fname) for fname in weight_files))
def init_demucs_seperator():
    """Fetch the demucs weights into `gs_demucs_model_path` if needed and build the separator."""
    demucs_env_prepare(logging, gs_demucs_model_path)
    return main_seperator()
def init_svc_model():
    """Ensure SVC model files are present, then load them.

    Returns:
        (embed_model, hubert_model, cs_sim) where `cs_sim` is a `cos_similar` instance.
    """
    meisheng_env_prepare(logging, gs_model_dir)
    embed_model, hubert_model = load_model()
    return embed_model, hubert_model, cos_similar()
def download_volume_adjustment():
    """
    Make sure the ebur128 loudness tool is present in `gs_model_dir` and executable.

    The binary is downloaded only when missing (failure is logged, not raised).
    The exec bit is (re)applied whenever the file exists. `volume_adjustment`
    runs it directly, so a copy that exists without +x would otherwise make every
    adjustment silently fail.
    :return: None
    """
    volume_bin_path = os.path.join(gs_model_dir, "ebur128_tool")
    if not os.path.exists(volume_bin_path):
        if not download2disk(gs_volume_bin_url, volume_bin_path):
            svc_offline_logger.fatal(f"download volume_bin err={gs_volume_bin_url}")
    if os.path.exists(volume_bin_path):
        os.system(f"chmod +x {volume_bin_path}")
def volume_adjustment(wav_path, target_loudness, out_path):
    """
    Volume adjustment: run the ebur128 tool on `wav_path`, writing to `out_path`.
    :param wav_path: input audio file
    :param target_loudness: target loudness value passed through to the tool
    :param out_path: output audio file
    :return: None (the tool's exit status is not checked)
    """
    tool = os.path.join(gs_model_dir, "ebur128_tool")
    os.system(f"{tool} {wav_path} {target_loudness} {out_path}")
class SVCOnline:
    """Owns every model the online SVC service needs and processes one request at a time."""

    def __init__(self):
        st = time.time()
        self.gender_model = init_gender_model()
        self.embed_model, self.hubert_model, self.cs_sim = init_svc_model()
        self.demucs_seprator = init_demucs_seperator()
        download_volume_adjustment()
        download_volume_balanced()
        svc_offline_logger.info(f"svc init finished, sp = {time.time() - st}")

    def gender_process(self, worker_attr):
        """Classify the singer's gender and validate the vocal ratio of the input.

        Returns:
            (gender, err_code) where gender is 'female' or 'male'. err_code is
            `gs_err_code_gender_classify` when no female rate is available, or
            `gs_err_code_vocal_ratio` when the vocal ratio is below the
            gender-specific threshold and demucs separation does not succeed.
            Otherwise it is `gs_err_code_success`.
        """
        st = time.time()
        gender, female_rate, is_pure = self.gender_model.process(worker_attr.vocal_path)
        svc_offline_logger.info(
            f"{worker_attr.vocal_url}, gender={gender}, female_rate={female_rate}, is_pure={is_pure}, "
            f"gender_process sp = {time.time() - st}")
        if gender == 0:
            gender = 'female'
        elif gender == 1:
            gender = 'male'
        elif female_rate is None:
            gender = 'male'
            return gender, gs_err_code_gender_classify
        elif female_rate > 0.5:
            gender = 'female'
        else:
            gender = 'male'

        # Female voices are accepted down to a lower vocal ratio than male voices.
        min_vocal_ratio = 0.5 if gender == 'female' else 0.6
        if self.gender_model.vocal_ratio < min_vocal_ratio:
            print(f"@@@ vocal_ratio: {self.gender_model.vocal_ratio}, gender : {gender}, gs_err_code_vocal_ratio : {gs_err_code_vocal_ratio}")
            # Give the input a second chance through demucs source separation.
            # NOTE(review): on success the separated vocals go to
            # worker_attr.vocal_demucs_path, but process() still passes
            # worker_attr.vocal_path to process_svc_online — confirm which is intended.
            if not self.demucs_seprator.process_one(worker_attr.vocal_path, worker_attr.vocal_demucs_path):
                return gender, gs_err_code_vocal_ratio

        svc_offline_logger.info(f"{worker_attr.vocal_url}, modified gender={gender}")
        return gender, gs_err_code_success

    def process(self, worker_attr):
        """Run gender classification, then SVC with the gender-matched source recording.

        Returns:
            (gender, err_code) — the classification error if it failed, else the
            code returned by `process_svc_online`.
        """
        gender, err = self.gender_process(worker_attr)
        if err != gs_err_code_success:
            return gender, err
        song_path = worker_attr.male_svc_source_path if gender == "male" else worker_attr.female_svc_source_path
        params = {'gender': gender, 'tst': worker_attr.st_tm, "tnd": worker_attr.ed_tm, 'delay': 0, 'song_path': None}
        st = time.time()
        err_code = process_svc_online(song_path, worker_attr.vocal_path, worker_attr.target_wav_path, self.embed_model,
                                      self.hubert_model, self.cs_sim, params)
        svc_offline_logger.info(f"{worker_attr.vocal_url}, err_code={err_code} process svc sp = {time.time() - st}")
        return gender, err_code
diff --git a/AIMeiSheng/meisheng_env_preparex.py b/AIMeiSheng/meisheng_env_preparex.py
index 079ba38..f37fc2d 100644
--- a/AIMeiSheng/meisheng_env_preparex.py
+++ b/AIMeiSheng/meisheng_env_preparex.py
@@ -1,56 +1,71 @@
import os
from AIMeiSheng.docker_demo.common import (gs_svc_model_path, gs_hubert_model_path, gs_embed_model_path,gs_embed_model_spk_path, gs_embed_config_spk_path, gs_rmvpe_model_path, download2disk)
def _download_if_missing(logging, url, dst_path, label):
    """Download `url` to `dst_path` unless it already exists; log `label` on failure.

    The parent directory is created first, so a download cannot fail merely
    because the model sub-directory does not exist yet.
    """
    parent = os.path.dirname(dst_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    if not os.path.exists(dst_path):
        if not download2disk(url, dst_path):
            logging.fatal(f"download {label} err={url}")


def meisheng_env_prepare(logging, AIMeiSheng_Path='./'):
    """Fetch every model file the SVC pipeline needs from COS, skipping files already present.

    Args:
        logging: object exposing `.fatal(msg)` (the `logging` module is passed in practice).
            Failures are only logged; this function does not raise on them.
        AIMeiSheng_Path: unused; kept for backward compatibility with existing callers.
    """
    cos_path = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/dataset/AIMeiSheng/"
    #model_svc = "xusong_v2_org_version_alldata_embed1_enzx_diff_fi_e15_s244110.pth"
    #model_svc = "xusong_v2_org_version_alldata_embed1_enzx_diff_ocean_ctl_enc_e22_s363704.pth"
    #model_svc = "xusong_v2_org_version_alldata_embed_spkenx200x_vocal_e22_s95040.pth"
    model_svc = "xusong_v2_org_version_alldata_embed_spkenx200x_double_e14_s90706.pth"
    # (remote file name, local destination, label used in the failure log)
    downloads = [
        ("rmvpe.pt", gs_rmvpe_model_path, "rmvpe_model"),
        ("hubert_base.pt", gs_hubert_model_path, "hubert_model"),
        (model_svc, gs_svc_model_path, "svc_model"),
        ("model.pt", gs_embed_model_path, "embed_model"),
        ("best_model.pth.tar", gs_embed_model_spk_path, "embed_model"),
        ("config.json", gs_embed_config_spk_path, "embed_model"),
    ]
    for remote_name, dst_path, label in downloads:
        _download_if_missing(logging, cos_path + remote_name, dst_path, label)
+
def demucs_env_prepare(logging, gs_demucs_model_path):
    """Make sure all demucs weight files are present in `gs_demucs_model_path`.

    Each file is checked individually, so a directory left incomplete by an
    earlier failed download is repaired on the next call. Previously, files were
    only fetched when the directory itself was missing. Paths are joined with
    `os.path.join`, so the directory may be given with or without a trailing "/".

    Args:
        logging: object exposing `.fatal(msg)`; download failures are logged, not raised.
        gs_demucs_model_path: local directory for the weight files (created if needed).
    """
    cos_path = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/dataset/AIMeiSheng/"

    model_demucs_name_list = ["e51eebcc-c1b80bdd.th", "a1d90b5c-ae9d2452.th", "5d2d6c55-db83574e.th", "cfa93e08-61801ae1.th"]
    os.makedirs(gs_demucs_model_path, exist_ok=True)
    for model_part in model_demucs_name_list:
        dst_path = os.path.join(gs_demucs_model_path, model_part)
        if os.path.exists(dst_path):
            continue
        demucs_model_cos_url = cos_path + model_part
        if not download2disk(demucs_model_cos_url, dst_path):
            logging.fatal(f"download demucs_model err={demucs_model_cos_url}")
+
+
+
if __name__ == "__main__":
    # meisheng_env_prepare requires a logger-like first argument; calling it with
    # no arguments raised TypeError. Pass the stdlib logging module, as callers do.
    import logging
    meisheng_env_prepare(logging)
diff --git a/AIMeiSheng/separate_demucs.py b/AIMeiSheng/separate_demucs.py
new file mode 100644
index 0000000..99a28a8
--- /dev/null
+++ b/AIMeiSheng/separate_demucs.py
@@ -0,0 +1,264 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import sys, os
+from pathlib import Path
+
+from dora.log import fatal
+import torch as th
+
+from AIMeiSheng.demucs.api import Separator, save_audio, list_models
+
+from AIMeiSheng.demucs.apply import BagOfModels
+from AIMeiSheng.demucs.htdemucs import HTDemucs
+from AIMeiSheng.demucs.pretrained import add_model_flags, ModelLoadingError
+
+
def get_parser():
    """Build the command-line parser for Demucs source separation.

    Model-selection flags are added by ``add_model_flags``; the remaining
    options control the device, chunking, stem selection and output encoding.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser("demucs.separate",
                                     description="Separate the sources for the given tracks")
    parser.add_argument("tracks", nargs='*', type=Path, default=[], help='Path to tracks')
    add_model_flags(parser)
    parser.add_argument("--list-models", action="store_true", help="List available models "
                        "from current repo and exit")
    parser.add_argument("-v", "--verbose", action="store_true")
    # main() writes directly into this folder (the per-model subfolder of
    # upstream Demucs is disabled there), so the help no longer promises one.
    parser.add_argument("-o",
                        "--out",
                        type=Path,
                        default=Path("separated"),
                        help="Folder where to put extracted tracks.")

    parser.add_argument("--filename",
                        default="{track}/{stem}.{ext}",
                        help="Set the name of output file. \n"
                        'Use "{track}", "{trackext}", "{stem}", "{ext}" to use '
                        "variables of track name without extension, track extension, "
                        "stem name and default output file extension. \n"
                        'Default is "{track}/{stem}.{ext}".')
    parser.add_argument("-d",
                        "--device",
                        default=(
                            "cuda"
                            if th.cuda.is_available()
                            else "mps"
                            if th.backends.mps.is_available()
                            else "cpu"
                        ),
                        help="Device to use, default is cuda if available, "
                        "else mps if available, else cpu")
    parser.add_argument("--shifts",
                        default=1,
                        type=int,
                        help="Number of random shifts for equivariant stabilization. "
                        "Increase separation time but improves quality for Demucs. 10 was used "
                        "in the original paper.")
    parser.add_argument("--overlap",
                        default=0.25,
                        type=float,
                        help="Overlap between the splits.")
    split_group = parser.add_mutually_exclusive_group()
    split_group.add_argument("--no-split",
                             action="store_false",
                             dest="split",
                             default=True,
                             help="Doesn't split audio in chunks. "
                             "This can use large amounts of memory.")
    split_group.add_argument("--segment", type=int,
                             help="Set split size of each chunk. "
                             "This can help save memory of graphic card. ")
    parser.add_argument("--two-stems",
                        dest="stem", metavar="STEM",
                        help="Only separate audio into {STEM} and no_{STEM}. ")
    parser.add_argument("--other-method", dest="other_method", choices=["none", "add", "minus"],
                        default="add", help='Decide how to get "no_{STEM}". "none" will not save '
                        '"no_{STEM}". "add" will add all the other stems. "minus" will use the '
                        "original track minus the selected stem.")
    depth_group = parser.add_mutually_exclusive_group()
    depth_group.add_argument("--int24", action="store_true",
                             help="Save wav output as 24 bits wav.")
    depth_group.add_argument("--float32", action="store_true",
                             help="Save wav output as float32 (2x bigger).")
    parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp", "none"],
                        help="Strategy for avoiding clipping: rescaling entire signal "
                        "if necessary (rescale) or hard clipping (clamp).")
    format_group = parser.add_mutually_exclusive_group()
    format_group.add_argument("--flac", action="store_true",
                              help="Convert the output wavs to flac.")
    format_group.add_argument("--mp3", action="store_true",
                              help="Convert the output wavs to mp3.")
    parser.add_argument("--mp3-bitrate",
                        default=320,
                        type=int,
                        help="Bitrate of converted mp3.")
    parser.add_argument("--mp3-preset", choices=range(2, 8), type=int, default=2,
                        help="Encoder preset of MP3, 2 for highest quality, 7 for "
                        "fastest speed. Default is 2")
    parser.add_argument("-j", "--jobs",
                        default=0,
                        type=int,
                        help="Number of jobs. This can increase memory usage but will "
                        "be much faster when multiple cores are available.")

    return parser
+
+
def main_init(opts=None, model_name='mdx_extra'):
    """Parse the separation options and load the Demucs ``Separator``.

    Args:
        opts: list of command-line tokens handed to ``argparse``; ``None``
            makes argparse read ``sys.argv[1:]``.
        model_name: pretrained model to load. It always replaces whatever
            model name the parsed options carry. Defaults to ``'mdx_extra'``.

    Returns:
        tuple: ``(separator, args)`` with the loaded separator and the parsed
        ``argparse.Namespace`` (its ``name`` attribute set to ``model_name``).
    """
    parser = get_parser()
    args = parser.parse_args(opts)
    args.name = model_name
    try:
        separator = Separator(model=args.name,
                              repo=args.repo,
                              device=args.device,
                              shifts=args.shifts,
                              split=args.split,
                              overlap=args.overlap,
                              progress=True,
                              jobs=args.jobs,
                              segment=args.segment)
    except ModelLoadingError as error:
        # dora's ``fatal`` is expected to print the message and exit; re-raise
        # in case it returns, so callers never see an unbound ``separator``.
        fatal(error.args[0])
        raise

    return separator, args
+
def _format_stem_path(out, filename, track, stem_name, ext):
    """Return ``out`` joined with the ``filename`` template filled in for one stem of ``track``."""
    return out / filename.format(
        track=track.name.rsplit(".", 1)[0],
        trackext=track.name.rsplit(".", 1)[-1],
        stem=stem_name,
        ext=ext,
    )


def _save_stem(wav, path, kwargs):
    """Create the parent directory of ``path`` and write ``wav`` there via ``save_audio``."""
    path.parent.mkdir(parents=True, exist_ok=True)
    save_audio(wav, str(path), **kwargs)


def main(separator, args):
    """Separate every track in ``args.tracks`` and write the stems to disk.

    Args:
        separator: a loaded Demucs separator (as returned by ``main_init``).
        args: parsed options from ``get_parser``. ``args.out`` is the output
            directory and ``args.filename`` the per-stem file name template.

    Returns:
        str or None: path of the last file written, or ``None`` when no file
        was written (for example because every input track is missing).
    """
    max_allowed_segment = float('inf')
    if isinstance(separator.model, HTDemucs):
        max_allowed_segment = float(separator.model.segment)
    elif isinstance(separator.model, BagOfModels):
        max_allowed_segment = separator.model.max_allowed_segment
    if args.segment is not None and args.segment > max_allowed_segment:
        fatal("Cannot use a Transformer model with a longer segment "
              f"than it was trained for. Maximum segment is: {max_allowed_segment}")

    if isinstance(separator.model, BagOfModels):
        print(
            f"Selected model is a bag of {len(separator.model.models)} models. "
            "You will see that many progress bars per track."
        )

    if args.stem is not None and args.stem not in separator.model.sources:
        fatal(
            'error: stem "{stem}" is not in selected model. '
            "STEM must be one of {sources}.".format(
                stem=args.stem, sources=", ".join(separator.model.sources)
            )
        )

    # Files go directly under args.out; no per-model subfolder is added.
    out = args.out
    out.mkdir(parents=True, exist_ok=True)
    print(f"Separated tracks will be stored in {out.resolve()}")

    # Output format and encoder settings are identical for every track.
    if args.mp3:
        ext = "mp3"
    elif args.flac:
        ext = "flac"
    else:
        ext = "wav"
    kwargs = {
        "samplerate": separator.samplerate,
        "bitrate": args.mp3_bitrate,
        "preset": args.mp3_preset,
        "clip": args.clip_mode,
        "as_float": args.float32,
        "bits_per_sample": 24 if args.int24 else 16,
    }

    # Most recently written path; stays None if nothing gets written
    # (previously this raised UnboundLocalError at the final return).
    stem = None
    for track in args.tracks:
        if not track.exists():
            print(f"File {track} does not exist. If the path contains spaces, "
                  'please try again after surrounding the entire path with quotes "".',
                  file=sys.stderr)
            continue
        print(f"Separating track {track}")

        origin, res = separator.separate_audio_file(track)

        if args.stem is None:
            # Save every source. If the template has no {stem} placeholder,
            # all sources target the same file and the last one saved wins.
            for name, source in res.items():
                stem = _format_stem_path(out, args.filename, track, name, ext)
                _save_stem(source, stem, kwargs)
            continue

        if args.other_method == "minus":
            _save_stem(origin - res[args.stem],
                       _format_stem_path(out, args.filename, track, "minus_" + args.stem, ext),
                       kwargs)
        stem = _format_stem_path(out, args.filename, track, args.stem, ext)
        # pop() removes the selected stem, so ``res`` keeps only the other stems.
        _save_stem(res.pop(args.stem), stem, kwargs)
        if args.other_method == "add":
            other_stem = th.zeros_like(next(iter(res.values())))
            for wav in res.values():
                other_stem += wav
            stem = _format_stem_path(out, args.filename, track, "no_" + args.stem, ext)
            _save_stem(other_stem, stem, kwargs)

    return None if stem is None else str(stem)
+
class main_seperator():
    """Keep one loaded Demucs separator and reuse it for single-file jobs."""

    def __init__(self, modelname='mdx_extra'):
        """Load the separator once with the default options.

        Args:
            modelname: currently ignored; ``main_init`` decides which model
                is loaded.
        """
        # Parse an empty argument list: with ``None`` argparse would read the
        # host process's sys.argv. Any unrelated flag there would make argparse
        # exit, and positional arguments would be taken as tracks.
        self.separator, self.args = main_init([])

    def process_one(self, in_name, out_path):
        """Separate ``in_name`` and write the result to ``out_path``.

        The base name of ``out_path`` is used as the ``--filename`` template.
        Without a ``{stem}`` placeholder every stem is written to that same
        file, so the last stem saved is what remains on disk.

        Returns:
            The path string returned by ``main`` (the last file written).
        """
        self.args.tracks = [Path(in_name)]  # input track
        self.args.out = Path(os.path.dirname(out_path))  # output directory
        self.args.filename = os.path.basename(out_path)  # output file name template
        return main(self.separator, self.args)
+
if __name__ == "__main__":
    # Manual smoke run: separate a bundled test clip into test_out/out.wav.
    demo_separator = main_seperator()
    demo_separator.process_one("test_wav/千年之恋_2-a.wav", "test_out/out.wav")

File Metadata

Mime Type
text/x-diff
Expires
Sun, Jan 12, 08:36 (1 d, 11 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
1347225
Default Alt Text
(283 KB)

Event Timeline