diff --git a/AIMeiSheng/RawNet3/__pycache__/__init__.cpython-38.pyc b/AIMeiSheng/RawNet3/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index b977952..0000000 Binary files a/AIMeiSheng/RawNet3/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/__pycache__/cal_cos_distance_folder.cpython-38.pyc b/AIMeiSheng/RawNet3/__pycache__/cal_cos_distance_folder.cpython-38.pyc deleted file mode 100644 index 02da763..0000000 Binary files a/AIMeiSheng/RawNet3/__pycache__/cal_cos_distance_folder.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/__pycache__/cal_cos_distance_folder.cpython-39.pyc b/AIMeiSheng/RawNet3/__pycache__/cal_cos_distance_folder.cpython-39.pyc deleted file mode 100644 index fabcfd3..0000000 Binary files a/AIMeiSheng/RawNet3/__pycache__/cal_cos_distance_folder.cpython-39.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/__pycache__/infererence_fang_meisheng.cpython-38.pyc b/AIMeiSheng/RawNet3/__pycache__/infererence_fang_meisheng.cpython-38.pyc deleted file mode 100644 index 0003ccb..0000000 Binary files a/AIMeiSheng/RawNet3/__pycache__/infererence_fang_meisheng.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/__pycache__/infererence_fang_meisheng.cpython-39.pyc b/AIMeiSheng/RawNet3/__pycache__/infererence_fang_meisheng.cpython-39.pyc deleted file mode 100644 index be84168..0000000 Binary files a/AIMeiSheng/RawNet3/__pycache__/infererence_fang_meisheng.cpython-39.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/__pycache__/multi_threads_wraper.cpython-39.pyc b/AIMeiSheng/RawNet3/__pycache__/multi_threads_wraper.cpython-39.pyc deleted file mode 100644 index ca69a96..0000000 Binary files a/AIMeiSheng/RawNet3/__pycache__/multi_threads_wraper.cpython-39.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/__pycache__/utils.cpython-38.pyc b/AIMeiSheng/RawNet3/__pycache__/utils.cpython-38.pyc deleted file mode 100644 index 1bf1e45..0000000 Binary files a/AIMeiSheng/RawNet3/__pycache__/utils.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/__pycache__/utils.cpython-39.pyc b/AIMeiSheng/RawNet3/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index 231937c..0000000 Binary files a/AIMeiSheng/RawNet3/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/models/__pycache__/RawNet3.cpython-38.pyc b/AIMeiSheng/RawNet3/models/__pycache__/RawNet3.cpython-38.pyc deleted file mode 100644 index e5499c1..0000000 Binary files a/AIMeiSheng/RawNet3/models/__pycache__/RawNet3.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/models/__pycache__/RawNet3.cpython-39.pyc b/AIMeiSheng/RawNet3/models/__pycache__/RawNet3.cpython-39.pyc deleted file mode 100644 index 643c555..0000000 Binary files a/AIMeiSheng/RawNet3/models/__pycache__/RawNet3.cpython-39.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/models/__pycache__/RawNetBasicBlock.cpython-38.pyc b/AIMeiSheng/RawNet3/models/__pycache__/RawNetBasicBlock.cpython-38.pyc deleted file mode 100644 index a29ed34..0000000 Binary files a/AIMeiSheng/RawNet3/models/__pycache__/RawNetBasicBlock.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/models/__pycache__/RawNetBasicBlock.cpython-39.pyc b/AIMeiSheng/RawNet3/models/__pycache__/RawNetBasicBlock.cpython-39.pyc deleted file mode 100644 index cffbcf6..0000000 Binary files a/AIMeiSheng/RawNet3/models/__pycache__/RawNetBasicBlock.cpython-39.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/models/__pycache__/__init__.cpython-38.pyc b/AIMeiSheng/RawNet3/models/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index f00a24c..0000000 Binary files a/AIMeiSheng/RawNet3/models/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/RawNet3/models/__pycache__/__init__.cpython-39.pyc b/AIMeiSheng/RawNet3/models/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index a0da605..0000000 Binary files a/AIMeiSheng/RawNet3/models/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/AIMeiSheng/diffuse_fang/__init__.py b/AIMeiSheng/diffuse_fang/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/AIMeiSheng/diffuse_fang/diffUse_fang.py b/AIMeiSheng/diffuse_fang/diffUse_fang.py new file mode 100644 index 0000000..b216924 --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffUse_fang.py @@ -0,0 +1,42 @@ +from diffusion.wavenet import WaveNet +from diffusion.diffusion import GaussianDiffusion + +import torch +out_dims = 192#128 ##决定输出维度 +n_layers=20 +n_chans=384 +n_hidden=128#256 ###决定输入维度 +timesteps=1000 +k_step_max=1000 + +###out: B x n_frames x feat, 推理的话returrn 目标数据,训练的时候return 是 mse loss +##GaussianDiffusion 我做了更改推理的时候范围预测结果(1个),训练时候返回loss和重构预测的特征(2个) +diff_decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden),timesteps=timesteps,k_step=k_step_max, out_dims=out_dims) + +gt_spec=None#这个是x0的数据,推理不需要,测试需要 +infer=True # train的时候设置成Fasle +infer_speedup=10 +method='dpm-solver' +k_step=100 +use_tqdm=True + +if __name__ == "__main__": + + B = 32 + n_frames = 120 + n_unit = n_hidden + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + diff_decoder = diff_decoder.to(device) + x = torch.randn(B, n_frames,n_unit).to(device) ##input: B x n_frames x n_unit + print("@@@ input x shape:", x.shape) + # 生成标签数据(假设简单线性分类) + # Y = torch.randint(0, 2, (num_samples, output_dim)).float() + #gt_spec在训练的时候是label,infer的时候是None + #x = x.half() + #diff_decoder = diff_decoder.half() + out = diff_decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, + use_tqdm=use_tqdm) + print("@@@ out shape:",out.shape) #torch.Size([32, 120, 128]) ###out: B x n_frames x feat + print("out:",out[0,0,:]) + diff --git a/AIMeiSheng/diffuse_fang/diffUse_wraper.py b/AIMeiSheng/diffuse_fang/diffUse_wraper.py new file mode 100644 index 0000000..87a8889 --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffUse_wraper.py @@ -0,0 +1,59 @@ +from diffuse_fang.diffusion.wavenet import WaveNet +from diffuse_fang.diffusion.diffusion import GaussianDiffusion + +import torch + +out_dims = 192 ##决定输出维度 +n_layers=20 +n_chans=384 +n_hidden=192#256 ##决定输入维度 +timesteps=1000 +k_step_max=1000 + + +#class WaveNet(nn.Module): +# def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256): + +###out: B x n_frames x feat, 推理的话returrn 目标数据,训练的时候return 是 mse loss +#input size +#output size: +diff_decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden),timesteps=timesteps,k_step=k_step_max, out_dims=out_dims) + +''' +gt_spec=None#这个是x0的数据,推理不需要,测试需要 +infer=True # train的时候设置成Fasle +infer_speedup=10 +method='dpm-solver' +k_step=100 +use_tqdm=True +#''' + +class ddpm_para(): + def __init__(self, gt_spec=None,infer=True,infer_speedup=10,method='dpm-solver',k_step=100,use_tqdm = True): + #self.use_tqdm = use_tqdm #True + self.gt_spec = gt_spec#None#这个是x0的数据,推理不需要,测试需要 + self.infer = infer #True # train的时候设置成Fasle + self.infer_speedup = infer_speedup#10 + self.method = method #'dpm-solver' + self.k_step = k_step + self.use_tqdm = use_tqdm + + +if __name__ == "__main__": + ddpm_dp = ddpm_para() + + B = 32 + n_frames = 120 + n_unit = 192 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + diff_decoder = diff_decoder.to(device) + x = torch.randn(B, n_frames,n_unit).to(device) ##input: B x n_frames x n_unit + print("@@@ input x shape:", x.shape) + # 生成标签数据(假设简单线性分类) + # Y = torch.randint(0, 2, (num_samples, output_dim)).float() + + out = diff_decoder(x, gt_spec=ddpm_dp.gt_spec, infer=ddpm_dp.infer, infer_speedup=ddpm_dp.infer_speedup, method=ddpm_dp.method, k_step=ddpm_dp.k_step, use_tqdm=ddpm_dp.use_tqdm) + print("@@@ out shape:",out.shape) #torch.Size([32, 120, 128]) ###out: B x n_frames x feat + + diff --git a/AIMeiSheng/diffuse_fang/diffusion/__init__.py b/AIMeiSheng/diffuse_fang/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/AIMeiSheng/diffuse_fang/diffusion/data_loaders.py b/AIMeiSheng/diffuse_fang/diffusion/data_loaders.py new file mode 100644 index 0000000..9f00b9a --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/data_loaders.py @@ -0,0 +1,288 @@ +import os +import random + +import librosa +import numpy as np +import torch +from torch.utils.data import Dataset +from tqdm import tqdm + +from utils import repeat_expand_2d + + +def traverse_dir( + root_dir, + extensions, + amount=None, + str_include=None, + str_exclude=None, + is_pure=False, + is_sort=False, + is_ext=True): + + file_list = [] + cnt = 0 + for root, _, files in os.walk(root_dir): + for file in files: + if any([file.endswith(f".{ext}") for ext in extensions]): + # path + mix_path = os.path.join(root, file) + pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path + + # amount + if (amount is not None) and (cnt == amount): + if is_sort: + file_list.sort() + return file_list + + # check string + if (str_include is not None) and (str_include not in pure_path): + continue + if (str_exclude is not None) and (str_exclude in pure_path): + continue + + if not is_ext: + ext = pure_path.split('.')[-1] + pure_path = pure_path[:-(len(ext)+1)] + file_list.append(pure_path) + cnt += 1 + if is_sort: + file_list.sort() + return file_list + + +def get_data_loaders(args, whole_audio=False): + data_train = AudioDataset( + filelists = args.data.training_files, + waveform_sec=args.data.duration, + hop_size=args.data.block_size, + sample_rate=args.data.sampling_rate, + load_all_data=args.train.cache_all_data, + whole_audio=whole_audio, + extensions=args.data.extensions, + n_spk=args.model.n_spk, + spk=args.spk, + device=args.train.cache_device, + fp16=args.train.cache_fp16, + unit_interpolate_mode = args.data.unit_interpolate_mode, + use_aug=True) + loader_train = torch.utils.data.DataLoader( + data_train , + batch_size=args.train.batch_size if not whole_audio else 1, + shuffle=True, + num_workers=args.train.num_workers if args.train.cache_device=='cpu' else 0, + persistent_workers=(args.train.num_workers > 0) if args.train.cache_device=='cpu' else False, + pin_memory=True if args.train.cache_device=='cpu' else False + ) + data_valid = AudioDataset( + filelists = args.data.validation_files, + waveform_sec=args.data.duration, + hop_size=args.data.block_size, + sample_rate=args.data.sampling_rate, + load_all_data=args.train.cache_all_data, + whole_audio=True, + spk=args.spk, + extensions=args.data.extensions, + unit_interpolate_mode = args.data.unit_interpolate_mode, + n_spk=args.model.n_spk) + loader_valid = torch.utils.data.DataLoader( + data_valid, + batch_size=1, + shuffle=False, + num_workers=0, + pin_memory=True + ) + return loader_train, loader_valid + + +class AudioDataset(Dataset): + def __init__( + self, + filelists, + waveform_sec, + hop_size, + sample_rate, + spk, + load_all_data=True, + whole_audio=False, + extensions=['wav'], + n_spk=1, + device='cpu', + fp16=False, + use_aug=False, + unit_interpolate_mode = 'left' + ): + super().__init__() + + self.waveform_sec = waveform_sec + self.sample_rate = sample_rate + self.hop_size = hop_size + self.filelists = filelists + self.whole_audio = whole_audio + self.use_aug = use_aug + self.data_buffer={} + self.pitch_aug_dict = {} + self.unit_interpolate_mode = unit_interpolate_mode + # np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item() + if load_all_data: + print('Load all the data filelists:', filelists) + else: + print('Load the f0, volume data filelists:', filelists) + with open(filelists,"r") as f: + self.paths = f.read().splitlines() + for name_ext in tqdm(self.paths, total=len(self.paths)): + path_audio = name_ext + duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate) + + path_f0 = name_ext + ".f0.npy" + f0,_ = np.load(path_f0,allow_pickle=True) + f0 = torch.from_numpy(np.array(f0,dtype=float)).float().unsqueeze(-1).to(device) + + path_volume = name_ext + ".vol.npy" + volume = np.load(path_volume) + volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device) + + path_augvol = name_ext + ".aug_vol.npy" + aug_vol = np.load(path_augvol) + aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device) + + if n_spk is not None and n_spk > 1: + spk_name = name_ext.split("/")[-2] + spk_id = spk[spk_name] if spk_name in spk else 0 + if spk_id < 0 or spk_id >= n_spk: + raise ValueError(' [x] Muiti-speaker traing error : spk_id must be a positive integer from 0 to n_spk-1 ') + else: + spk_id = 0 + spk_id = torch.LongTensor(np.array([spk_id])).to(device) + + if load_all_data: + ''' + audio, sr = librosa.load(path_audio, sr=self.sample_rate) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio) + audio = torch.from_numpy(audio).to(device) + ''' + path_mel = name_ext + ".mel.npy" + mel = np.load(path_mel) + mel = torch.from_numpy(mel).to(device) + + path_augmel = name_ext + ".aug_mel.npy" + aug_mel,keyshift = np.load(path_augmel, allow_pickle=True) + aug_mel = np.array(aug_mel,dtype=float) + aug_mel = torch.from_numpy(aug_mel).to(device) + self.pitch_aug_dict[name_ext] = keyshift + + path_units = name_ext + ".soft.pt" + units = torch.load(path_units).to(device) + units = units[0] + units = repeat_expand_2d(units,f0.size(0),unit_interpolate_mode).transpose(0,1) + + if fp16: + mel = mel.half() + aug_mel = aug_mel.half() + units = units.half() + + self.data_buffer[name_ext] = { + 'duration': duration, + 'mel': mel, + 'aug_mel': aug_mel, + 'units': units, + 'f0': f0, + 'volume': volume, + 'aug_vol': aug_vol, + 'spk_id': spk_id + } + else: + path_augmel = name_ext + ".aug_mel.npy" + aug_mel,keyshift = np.load(path_augmel, allow_pickle=True) + self.pitch_aug_dict[name_ext] = keyshift + self.data_buffer[name_ext] = { + 'duration': duration, + 'f0': f0, + 'volume': volume, + 'aug_vol': aug_vol, + 'spk_id': spk_id + } + + + def __getitem__(self, file_idx): + name_ext = self.paths[file_idx] + data_buffer = self.data_buffer[name_ext] + # check duration. if too short, then skip + if data_buffer['duration'] < (self.waveform_sec + 0.1): + return self.__getitem__( (file_idx + 1) % len(self.paths)) + + # get item + return self.get_data(name_ext, data_buffer) + + def get_data(self, name_ext, data_buffer): + name = os.path.splitext(name_ext)[0] + frame_resolution = self.hop_size / self.sample_rate + duration = data_buffer['duration'] + waveform_sec = duration if self.whole_audio else self.waveform_sec + + # load audio + idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1) + start_frame = int(idx_from / frame_resolution) + units_frame_len = int(waveform_sec / frame_resolution) + aug_flag = random.choice([True, False]) and self.use_aug + ''' + audio = data_buffer.get('audio') + if audio is None: + path_audio = os.path.join(self.path_root, 'audio', name) + '.wav' + audio, sr = librosa.load( + path_audio, + sr = self.sample_rate, + offset = start_frame * frame_resolution, + duration = waveform_sec) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio) + # clip audio into N seconds + audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size] + audio = torch.from_numpy(audio).float() + else: + audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size] + ''' + # load mel + mel_key = 'aug_mel' if aug_flag else 'mel' + mel = data_buffer.get(mel_key) + if mel is None: + mel = name_ext + ".mel.npy" + mel = np.load(mel) + mel = mel[start_frame : start_frame + units_frame_len] + mel = torch.from_numpy(mel).float() + else: + mel = mel[start_frame : start_frame + units_frame_len] + + # load f0 + f0 = data_buffer.get('f0') + aug_shift = 0 + if aug_flag: + aug_shift = self.pitch_aug_dict[name_ext] + f0_frames = 2 ** (aug_shift / 12) * f0[start_frame : start_frame + units_frame_len] + + # load units + units = data_buffer.get('units') + if units is None: + path_units = name_ext + ".soft.pt" + units = torch.load(path_units) + units = units[0] + units = repeat_expand_2d(units,f0.size(0),self.unit_interpolate_mode).transpose(0,1) + + units = units[start_frame : start_frame + units_frame_len] + + # load volume + vol_key = 'aug_vol' if aug_flag else 'volume' + volume = data_buffer.get(vol_key) + volume_frames = volume[start_frame : start_frame + units_frame_len] + + # load spk_id + spk_id = data_buffer.get('spk_id') + + # load shift + aug_shift = torch.from_numpy(np.array([[aug_shift]])).float() + + return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name, name_ext=name_ext) + + def __len__(self): + return len(self.paths) \ No newline at end of file diff --git a/AIMeiSheng/diffuse_fang/diffusion/diffusion.py b/AIMeiSheng/diffuse_fang/diffusion/diffusion.py new file mode 100644 index 0000000..edb3be5 --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/diffusion.py @@ -0,0 +1,398 @@ +from collections import deque +from functools import partial +from inspect import isfunction + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from tqdm import tqdm + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + def repeat_noise(): + return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + def noise(): + return torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def linear_beta_schedule(timesteps, max_beta=0.02): + """ + linear schedule + """ + betas = np.linspace(1e-4, max_beta, timesteps) + return betas + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return np.clip(betas, a_min=0, a_max=0.999) + + +beta_schedule = { + "cosine": cosine_beta_schedule, + "linear": linear_beta_schedule, +} + + +class GaussianDiffusion(nn.Module): + def __init__(self, + denoise_fn, + out_dims=128, + timesteps=1000, + k_step=1000, + max_beta=0.02, + spec_min=-12, + spec_max=2): + + super().__init__() + self.denoise_fn = denoise_fn + self.out_dims = out_dims + betas = beta_schedule['linear'](timesteps, max_beta=max_beta) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.k_step = k_step if k_step>0 and k_step 1: + if method == 'dpm-solver' or method == 'dpm-solver++': + from .dpm_solver_pytorch import ( + DPM_Solver, + NoiseScheduleVP, + model_wrapper, + ) + # 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) + + # 2. Convert your discrete-time `model` to the continuous-time + # noise prediction model. Here is an example for a diffusion model + # `model` with the noise prediction type ("noise") . + def my_wrapper(fn): + def wrapped(x, t, **kwargs): + ret = fn(x, t, **kwargs) + if use_tqdm: + self.bar.update(1) + return ret + + return wrapped + + model_fn = model_wrapper( + my_wrapper(self.denoise_fn), + noise_schedule, + model_type="noise", # or "x_start" or "v" or "score" + model_kwargs={"cond": cond} + ) + + # 3. Define dpm-solver and sample by singlestep DPM-Solver. + # (We recommend singlestep DPM-Solver for unconditional sampling) + # You can adjust the `steps` to balance the computation + # costs and the sample quality. + if method == 'dpm-solver': + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + elif method == 'dpm-solver++': + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + + steps = t // infer_speedup + if use_tqdm: + self.bar = tqdm(desc="sample time step", total=steps) + x = dpm_solver.sample( + x, + steps=steps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + if use_tqdm: + self.bar.close() + elif method == 'pndm': + self.noise_list = deque(maxlen=4) + if use_tqdm: + for i in tqdm( + reversed(range(0, t, infer_speedup)), desc='sample time step', + total=t // infer_speedup, + ): + x = self.p_sample_plms( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + else: + for i in reversed(range(0, t, infer_speedup)): + x = self.p_sample_plms( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + elif method == 'ddim': + if use_tqdm: + for i in tqdm( + reversed(range(0, t, infer_speedup)), desc='sample time step', + total=t // infer_speedup, + ): + x = self.p_sample_ddim( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + else: + for i in reversed(range(0, t, infer_speedup)): + x = self.p_sample_ddim( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + elif method == 'unipc': + from .uni_pc import NoiseScheduleVP, UniPC, model_wrapper + # 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) + + # 2. Convert your discrete-time `model` to the continuous-time + # noise prediction model. Here is an example for a diffusion model + # `model` with the noise prediction type ("noise") . + def my_wrapper(fn): + def wrapped(x, t, **kwargs): + ret = fn(x, t, **kwargs) + if use_tqdm: + self.bar.update(1) + return ret + + return wrapped + + model_fn = model_wrapper( + my_wrapper(self.denoise_fn), + noise_schedule, + model_type="noise", # or "x_start" or "v" or "score" + model_kwargs={"cond": cond} + ) + + # 3. Define uni_pc and sample by multistep UniPC. + # You can adjust the `steps` to balance the computation + # costs and the sample quality. + uni_pc = UniPC(model_fn, noise_schedule, variant='bh2') + + steps = t // infer_speedup + if use_tqdm: + self.bar = tqdm(desc="sample time step", total=steps) + x = uni_pc.sample( + x, + steps=steps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + if use_tqdm: + self.bar.close() + else: + raise NotImplementedError(method) + else: + if use_tqdm: + for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + else: + for i in reversed(range(0, t)): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + x = x.squeeze(1).transpose(1, 2) # [B, T, M] + return self.denorm_spec(x) + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def denorm_spec(self, x): + return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min diff --git a/AIMeiSheng/diffuse_fang/diffusion/diffusion_onnx.py b/AIMeiSheng/diffuse_fang/diffusion/diffusion_onnx.py new file mode 100644 index 0000000..f01e463 --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/diffusion_onnx.py @@ -0,0 +1,614 @@ +import math +from collections import deque +from functools import partial +from inspect import isfunction + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Conv1d, Mish +from tqdm import tqdm + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def extract(a, t): + return a[t].reshape((1, 1, 1, 1)) + + +def noise_like(shape, device, repeat=False): + def repeat_noise(): + return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + def noise(): + return torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def linear_beta_schedule(timesteps, max_beta=0.02): + """ + linear schedule + """ + betas = np.linspace(1e-4, max_beta, timesteps) + return betas + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return np.clip(betas, a_min=0, a_max=0.999) + + +beta_schedule = { + "cosine": cosine_beta_schedule, + "linear": linear_beta_schedule, +} + + +def extract_1(a, t): + return a[t].reshape((1, 1, 1, 1)) + + +def predict_stage0(noise_pred, noise_pred_prev): + return (noise_pred + noise_pred_prev) / 2 + + +def predict_stage1(noise_pred, noise_list): + return (noise_pred * 3 + - noise_list[-1]) / 2 + + +def predict_stage2(noise_pred, noise_list): + return (noise_pred * 23 + - noise_list[-1] * 16 + + noise_list[-2] * 5) / 12 + + +def predict_stage3(noise_pred, noise_list): + return (noise_pred * 55 + - noise_list[-1] * 59 + + noise_list[-2] * 37 + - noise_list[-3] * 9) / 24 + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + self.half_dim = dim // 2 + self.emb = 9.21034037 / (self.half_dim - 1) + self.emb = torch.exp(torch.arange(self.half_dim) * torch.tensor(-self.emb)).unsqueeze(0) + self.emb = self.emb.cpu() + + def forward(self, x): + emb = self.emb * x + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class ResidualBlock(nn.Module): + def __init__(self, encoder_hidden, residual_channels, dilation): + super().__init__() + self.residual_channels = residual_channels + self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) + self.diffusion_projection = nn.Linear(residual_channels, residual_channels) + self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1) + self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, conditioner, diffusion_step): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + conditioner = self.conditioner_projection(conditioner) + y = x + diffusion_step + y = self.dilated_conv(y) + conditioner + + gate, filter_1 = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + + y = torch.sigmoid(gate) * torch.tanh(filter_1) + y = self.output_projection(y) + + residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + + return (x + residual) / 1.41421356, skip + + +class DiffNet(nn.Module): + def __init__(self, in_dims, n_layers, n_chans, n_hidden): + super().__init__() + self.encoder_hidden = n_hidden + self.residual_layers = n_layers + self.residual_channels = n_chans + self.input_projection = Conv1d(in_dims, self.residual_channels, 1) + self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels) + dim = self.residual_channels + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), + Mish(), + nn.Linear(dim * 4, dim) + ) + self.residual_layers = nn.ModuleList([ + ResidualBlock(self.encoder_hidden, self.residual_channels, 1) + for i in range(self.residual_layers) + ]) + self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1) + self.output_projection = Conv1d(self.residual_channels, in_dims, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, diffusion_step, cond): + x = spec.squeeze(0) + x = self.input_projection(x) # x [B, residual_channel, T] + x = F.relu(x) + # skip = torch.randn_like(x) + diffusion_step = diffusion_step.float() + diffusion_step = self.diffusion_embedding(diffusion_step) + diffusion_step = self.mlp(diffusion_step) + + x, skip = self.residual_layers[0](x, cond, diffusion_step) + # noinspection PyTypeChecker + for layer in self.residual_layers[1:]: + x, skip_connection = layer.forward(x, cond, diffusion_step) + skip = skip + skip_connection + x = skip / math.sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) # [B, 80, T] + return x.unsqueeze(1) + + +class AfterDiffusion(nn.Module): + def __init__(self, spec_max, spec_min, v_type='a'): + super().__init__() + self.spec_max = spec_max + self.spec_min = spec_min + self.type = v_type + + def forward(self, x): + x = x.squeeze(1).permute(0, 2, 1) + mel_out = (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min + if self.type == 'nsf-hifigan-log10': + mel_out = mel_out * 0.434294 + return mel_out.transpose(2, 1) + + +class Pred(nn.Module): + def __init__(self, alphas_cumprod): + super().__init__() + self.alphas_cumprod = alphas_cumprod + + def forward(self, x_1, noise_t, t_1, t_prev): + a_t = extract(self.alphas_cumprod, t_1).cpu() + a_prev = extract(self.alphas_cumprod, t_prev).cpu() + a_t_sq, a_prev_sq = a_t.sqrt().cpu(), a_prev.sqrt().cpu() + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / ( + a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x_1 + x_delta.cpu() + + return x_pred + + +class GaussianDiffusion(nn.Module): + def __init__(self, + out_dims=128, + n_layers=20, + n_chans=384, + n_hidden=256, + timesteps=1000, + k_step=1000, + max_beta=0.02, + spec_min=-12, + spec_max=2): + super().__init__() + self.denoise_fn = DiffNet(out_dims, n_layers, n_chans, n_hidden) + self.out_dims = out_dims + self.mel_bins = out_dims + self.n_hidden = n_hidden + betas = beta_schedule['linear'](timesteps, max_beta=max_beta) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.k_step = k_step + + self.noise_list = deque(maxlen=4) + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + self.register_buffer('spec_min', torch.FloatTensor([spec_min])[None, None, :out_dims]) + self.register_buffer('spec_max', torch.FloatTensor([spec_max])[None, None, :out_dims]) + self.ad = AfterDiffusion(self.spec_max, self.spec_min) + self.xp = Pred(self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, cond): + noise_pred = self.denoise_fn(x, t, cond=cond) + x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) + + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False): + """ + Use the PLMS method from + [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778). + """ + + def get_x_pred(x, noise_t, t): + a_t = extract(self.alphas_cumprod, t) + a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t))) + a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() + + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / ( + a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x + x_delta + + return x_pred + + noise_list = self.noise_list + noise_pred = self.denoise_fn(x, t, cond=cond) + + if len(noise_list) == 0: + x_pred = get_x_pred(x, noise_pred, t) + noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond) + noise_pred_prime = (noise_pred + noise_pred_prev) / 2 + elif len(noise_list) == 1: + noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2 + elif len(noise_list) == 2: + noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12 + else: + noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24 + + x_prev = get_x_pred(x, noise_pred_prime, t) + noise_list.append(noise_pred) + + return x_prev + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond, noise=None, loss_type='l2'): + noise = default(noise, lambda: torch.randn_like(x_start)) + + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + x_recon = self.denoise_fn(x_noisy, t, cond) + + if loss_type == 'l1': + loss = (noise - x_recon).abs().mean() + elif loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: + raise NotImplementedError() + + return loss + + def org_forward(self, + condition, + init_noise=None, + gt_spec=None, + infer=True, + infer_speedup=100, + method='pndm', + k_step=1000, + use_tqdm=True): + """ + conditioning diffusion, use fastspeech2 encoder output as the condition + """ + cond = condition + b, device = condition.shape[0], condition.device + if not infer: + spec = self.norm_spec(gt_spec) + t = torch.randint(0, self.k_step, (b,), device=device).long() + norm_spec = spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + return self.p_losses(norm_spec, t, cond=cond) + else: + shape = (cond.shape[0], 1, self.out_dims, cond.shape[2]) + + if gt_spec is None: + t = self.k_step + if init_noise is None: + x = torch.randn(shape, device=device) + else: + x = init_noise + else: + t = k_step + norm_spec = self.norm_spec(gt_spec) + norm_spec = norm_spec.transpose(1, 2)[:, None, :, :] + x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long()) + + if method is not None and infer_speedup > 1: + if method == 'dpm-solver': + from .dpm_solver_pytorch import ( + DPM_Solver, + NoiseScheduleVP, + model_wrapper, + ) + # 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) + + # 2. Convert your discrete-time `model` to the continuous-time + # noise prediction model. Here is an example for a diffusion model + # `model` with the noise prediction type ("noise") . + def my_wrapper(fn): + def wrapped(x, t, **kwargs): + ret = fn(x, t, **kwargs) + if use_tqdm: + self.bar.update(1) + return ret + + return wrapped + + model_fn = model_wrapper( + my_wrapper(self.denoise_fn), + noise_schedule, + model_type="noise", # or "x_start" or "v" or "score" + model_kwargs={"cond": cond} + ) + + # 3. Define dpm-solver and sample by singlestep DPM-Solver. + # (We recommend singlestep DPM-Solver for unconditional sampling) + # You can adjust the `steps` to balance the computation + # costs and the sample quality. + dpm_solver = DPM_Solver(model_fn, noise_schedule) + + steps = t // infer_speedup + if use_tqdm: + self.bar = tqdm(desc="sample time step", total=steps) + x = dpm_solver.sample( + x, + steps=steps, + order=3, + skip_type="time_uniform", + method="singlestep", + ) + if use_tqdm: + self.bar.close() + elif method == 'pndm': + self.noise_list = deque(maxlen=4) + if use_tqdm: + for i in tqdm( + reversed(range(0, t, infer_speedup)), desc='sample time step', + total=t // infer_speedup, + ): + x = self.p_sample_plms( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + else: + for i in reversed(range(0, t, infer_speedup)): + x = self.p_sample_plms( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + else: + raise NotImplementedError(method) + else: + if use_tqdm: + for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + else: + for i in reversed(range(0, t)): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + x = x.squeeze(1).transpose(1, 2) # [B, T, M] + return self.denorm_spec(x).transpose(2, 1) + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def denorm_spec(self, x): + return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min + + def get_x_pred(self, x_1, noise_t, t_1, t_prev): + a_t = extract(self.alphas_cumprod, t_1) + a_prev = extract(self.alphas_cumprod, t_prev) + a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / ( + a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x_1 + x_delta + return x_pred + + def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, export_denoise=True, export_pred=True, export_after=True): + cond = torch.randn([1, self.n_hidden, 10]).cpu() + if init_noise is None: + x = torch.randn((1, 1, self.mel_bins, cond.shape[2]), dtype=torch.float32).cpu() + else: + x = init_noise + pndms = 100 + + org_y_x = self.org_forward(cond, init_noise=x) + + device = cond.device + n_frames = cond.shape[2] + step_range = torch.arange(0, self.k_step, pndms, dtype=torch.long, device=device).flip(0) + plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) + noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) + + ot = step_range[0] + ot_1 = torch.full((1,), ot, device=device, dtype=torch.long) + if export_denoise: + torch.onnx.export( + self.denoise_fn, + (x.cpu(), ot_1.cpu(), cond.cpu()), + f"{project_name}_denoise.onnx", + input_names=["noise", "time", "condition"], + output_names=["noise_pred"], + dynamic_axes={ + "noise": [3], + "condition": [2] + }, + opset_version=16 + ) + + for t in step_range: + t_1 = torch.full((1,), t, device=device, dtype=torch.long) + noise_pred = self.denoise_fn(x, t_1, cond) + t_prev = t_1 - pndms + t_prev = t_prev * (t_prev > 0) + if plms_noise_stage == 0: + if export_pred: + torch.onnx.export( + self.xp, + (x.cpu(), noise_pred.cpu(), t_1.cpu(), t_prev.cpu()), + f"{project_name}_pred.onnx", + input_names=["noise", "noise_pred", "time", "time_prev"], + output_names=["noise_pred_o"], + dynamic_axes={ + "noise": [3], + "noise_pred": [3] + }, + opset_version=16 + ) + + x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev) + noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond) + noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev) + + elif plms_noise_stage == 1: + noise_pred_prime = predict_stage1(noise_pred, noise_list) + + elif plms_noise_stage == 2: + noise_pred_prime = predict_stage2(noise_pred, noise_list) + + else: + noise_pred_prime = predict_stage3(noise_pred, noise_list) + + noise_pred = noise_pred.unsqueeze(0) + + if plms_noise_stage < 3: + noise_list = torch.cat((noise_list, noise_pred), dim=0) + plms_noise_stage = plms_noise_stage + 1 + + else: + noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0) + + x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev) + if export_after: + torch.onnx.export( + self.ad, + x.cpu(), + f"{project_name}_after.onnx", + input_names=["x"], + output_names=["mel_out"], + dynamic_axes={ + "x": [3] + }, + opset_version=16 + ) + x = self.ad(x) + + print((x == org_y_x).all()) + return x + + def forward(self, condition=None, init_noise=None, pndms=None, k_step=None): + cond = condition + x = init_noise + + device = cond.device + n_frames = cond.shape[2] + step_range = torch.arange(0, k_step.item(), pndms.item(), dtype=torch.long, device=device).flip(0) + plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) + noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) + + for t in step_range: + t_1 = torch.full((1,), t, device=device, dtype=torch.long) + noise_pred = self.denoise_fn(x, t_1, cond) + t_prev = t_1 - pndms + t_prev = t_prev * (t_prev > 0) + if plms_noise_stage == 0: + x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev) + noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond) + noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev) + + elif plms_noise_stage == 1: + noise_pred_prime = predict_stage1(noise_pred, noise_list) + + elif plms_noise_stage == 2: + noise_pred_prime = predict_stage2(noise_pred, noise_list) + + else: + noise_pred_prime = predict_stage3(noise_pred, noise_list) + + noise_pred = noise_pred.unsqueeze(0) + + if plms_noise_stage < 3: + noise_list = torch.cat((noise_list, noise_pred), dim=0) + plms_noise_stage = plms_noise_stage + 1 + + else: + noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0) + + x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev) + x = self.ad(x) + return x diff --git a/AIMeiSheng/diffuse_fang/diffusion/dpm_solver_pytorch.py b/AIMeiSheng/diffuse_fang/diffusion/dpm_solver_pytorch.py new file mode 100644 index 0000000..83ed73e --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/dpm_solver_pytorch.py @@ -0,0 +1,1307 @@ +import torch + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.T = 1. + self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + else: + self.T = 1. + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -expand_dims(sigma_t, x.dim()) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + sigma_t / sigma_s * x + - alpha_t * phi_1 * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + torch.exp(log_alpha_s1 - log_alpha_s) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (torch.exp(log_alpha_s1 - log_alpha_s)) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + - 0.5 * (alpha_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1. / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + def lower_update(x, s, t): + return self.dpm_solver_first_update(x, s, t, return_intermediate=True) + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + def lower_update(x, s, t): + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + def norm_fn(v): + return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type, + method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type, + atol=atol, rtol=rtol, return_intermediate=return_intermediate) + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == 'adaptive': + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order,] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/AIMeiSheng/diffuse_fang/diffusion/how to export onnx.md b/AIMeiSheng/diffuse_fang/diffusion/how to export onnx.md new file mode 100644 index 0000000..5aae72c --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/how to export onnx.md @@ -0,0 +1,4 @@ +- Open [onnx_export](onnx_export.py) +- project_name = "dddsp" change "project_name" to your project name +- model_path = f'{project_name}/model_500000.pt' change "model_path" to your model path +- Run \ No newline at end of file diff --git a/AIMeiSheng/diffuse_fang/diffusion/infer_gt_mel.py b/AIMeiSheng/diffuse_fang/diffusion/infer_gt_mel.py new file mode 100644 index 0000000..0bdf1fe --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/infer_gt_mel.py @@ -0,0 +1,74 @@ +import torch +import torch.nn.functional as F + +from diffusion.unit2mel import load_model_vocoder + + +class DiffGtMel: + def __init__(self, project_path=None, device=None): + self.project_path = project_path + if device is not None: + self.device = device + else: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = None + self.vocoder = None + self.args = None + + def flush_model(self, project_path, ddsp_config=None): + if (self.model is None) or (project_path != self.project_path): + model, vocoder, args = load_model_vocoder(project_path, device=self.device) + if self.check_args(ddsp_config, args): + self.model = model + self.vocoder = vocoder + self.args = args + + def check_args(self, args1, args2): + if args1.data.block_size != args2.data.block_size: + raise ValueError("DDSP与DIFF模型的block_size不一致") + if args1.data.sampling_rate != args2.data.sampling_rate: + raise ValueError("DDSP与DIFF模型的sampling_rate不一致") + if args1.data.encoder != args2.data.encoder: + raise ValueError("DDSP与DIFF模型的encoder不一致") + return True + + def __call__(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm', + spk_mix_dict=None, start_frame=0): + input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate) + out_mel = self.model( + hubert, + f0, + volume, + spk_id=spk_id, + spk_mix_dict=spk_mix_dict, + gt_spec=input_mel, + infer=True, + infer_speedup=acc, + method=method, + k_step=k_step, + use_tqdm=False) + if start_frame > 0: + out_mel = out_mel[:, start_frame:, :] + f0 = f0[:, start_frame:, :] + output = self.vocoder.infer(out_mel, f0) + if start_frame > 0: + output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0)) + return output + + def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm', silence_front=0, + use_silence=False, spk_mix_dict=None): + start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size) + if use_silence: + audio = audio[:, start_frame * self.vocoder.vocoder_hop_size:] + f0 = f0[:, start_frame:, :] + hubert = hubert[:, start_frame:, :] + volume = volume[:, start_frame:, :] + _start_frame = 0 + else: + _start_frame = start_frame + audio = self.__call__(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step, + method=method, spk_mix_dict=spk_mix_dict, start_frame=_start_frame) + if use_silence: + if start_frame > 0: + audio = F.pad(audio, (start_frame * self.vocoder.vocoder_hop_size, 0)) + return audio diff --git a/AIMeiSheng/diffuse_fang/diffusion/logger/__init__.py b/AIMeiSheng/diffuse_fang/diffusion/logger/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/AIMeiSheng/diffuse_fang/diffusion/logger/saver.py b/AIMeiSheng/diffuse_fang/diffusion/logger/saver.py new file mode 100644 index 0000000..954ce99 --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/logger/saver.py @@ -0,0 +1,145 @@ +''' +author: wayn391@mastertones +''' + +import datetime +import os +import time + +import matplotlib.pyplot as plt +import torch +import yaml +from torch.utils.tensorboard import SummaryWriter + + +class Saver(object): + def __init__( + self, + args, + initial_global_step=-1): + + self.expdir = args.env.expdir + self.sample_rate = args.data.sampling_rate + + # cold start + self.global_step = initial_global_step + self.init_time = time.time() + self.last_time = time.time() + + # makedirs + os.makedirs(self.expdir, exist_ok=True) + + # path + self.path_log_info = os.path.join(self.expdir, 'log_info.txt') + + # ckpt + os.makedirs(self.expdir, exist_ok=True) + + # writer + self.writer = SummaryWriter(os.path.join(self.expdir, 'logs')) + + # save config + path_config = os.path.join(self.expdir, 'config.yaml') + with open(path_config, "w") as out_config: + yaml.dump(dict(args), out_config) + + + def log_info(self, msg): + '''log method''' + if isinstance(msg, dict): + msg_list = [] + for k, v in msg.items(): + tmp_str = '' + if isinstance(v, int): + tmp_str = '{}: {:,}'.format(k, v) + else: + tmp_str = '{}: {}'.format(k, v) + + msg_list.append(tmp_str) + msg_str = '\n'.join(msg_list) + else: + msg_str = msg + + # dsplay + print(msg_str) + + # save + with open(self.path_log_info, 'a') as fp: + fp.write(msg_str+'\n') + + def log_value(self, dict): + for k, v in dict.items(): + self.writer.add_scalar(k, v, self.global_step) + + def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5): + spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1) + spec = spec_cat[0] + if isinstance(spec, torch.Tensor): + spec = spec.cpu().numpy() + fig = plt.figure(figsize=(12, 9)) + plt.pcolor(spec.T, vmin=vmin, vmax=vmax) + plt.tight_layout() + self.writer.add_figure(name, fig, self.global_step) + + def log_audio(self, dict): + for k, v in dict.items(): + self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate) + + def get_interval_time(self, update=True): + cur_time = time.time() + time_interval = cur_time - self.last_time + if update: + self.last_time = cur_time + return time_interval + + def get_total_time(self, to_str=True): + total_time = time.time() - self.init_time + if to_str: + total_time = str(datetime.timedelta( + seconds=total_time))[:-5] + return total_time + + def save_model( + self, + model, + optimizer, + name='model', + postfix='', + to_json=False): + # path + if postfix: + postfix = '_' + postfix + path_pt = os.path.join( + self.expdir , name+postfix+'.pt') + + # check + print(' [*] model checkpoint saved: {}'.format(path_pt)) + + # save + if optimizer is not None: + torch.save({ + 'global_step': self.global_step, + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict()}, path_pt) + else: + torch.save({ + 'global_step': self.global_step, + 'model': model.state_dict()}, path_pt) + + + def delete_model(self, name='model', postfix=''): + # path + if postfix: + postfix = '_' + postfix + path_pt = os.path.join( + self.expdir , name+postfix+'.pt') + + # delete + if os.path.exists(path_pt): + os.remove(path_pt) + print(' [*] model checkpoint deleted: {}'.format(path_pt)) + + def global_step_increment(self): + self.global_step += 1 + + diff --git a/AIMeiSheng/diffuse_fang/diffusion/logger/utils.py b/AIMeiSheng/diffuse_fang/diffusion/logger/utils.py new file mode 100644 index 0000000..a907de7 --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/logger/utils.py @@ -0,0 +1,127 @@ +import json +import os + +import torch +import yaml + + +def traverse_dir( + root_dir, + extensions, + amount=None, + str_include=None, + str_exclude=None, + is_pure=False, + is_sort=False, + is_ext=True): + + file_list = [] + cnt = 0 + for root, _, files in os.walk(root_dir): + for file in files: + if any([file.endswith(f".{ext}") for ext in extensions]): + # path + mix_path = os.path.join(root, file) + pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path + + # amount + if (amount is not None) and (cnt == amount): + if is_sort: + file_list.sort() + return file_list + + # check string + if (str_include is not None) and (str_include not in pure_path): + continue + if (str_exclude is not None) and (str_exclude in pure_path): + continue + + if not is_ext: + ext = pure_path.split('.')[-1] + pure_path = pure_path[:-(len(ext)+1)] + file_list.append(pure_path) + cnt += 1 + if is_sort: + file_list.sort() + return file_list + + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def get_network_paras_amount(model_dict): + info = dict() + for model_name, model in model_dict.items(): + # all_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + info[model_name] = trainable_params + return info + + +def load_config(path_config): + with open(path_config, "r") as config: + args = yaml.safe_load(config) + args = DotDict(args) + # print(args) + return args + +def save_config(path_config,config): + config = dict(config) + with open(path_config, "w") as f: + yaml.dump(config, f) + +def to_json(path_params, path_json): + params = torch.load(path_params, map_location=torch.device('cpu')) + raw_state_dict = {} + for k, v in params.items(): + val = v.flatten().numpy().tolist() + raw_state_dict[k] = val + + with open(path_json, 'w') as outfile: + json.dump(raw_state_dict, outfile,indent= "\t") + + +def convert_tensor_to_numpy(tensor, is_squeeze=True): + if is_squeeze: + tensor = tensor.squeeze() + if tensor.requires_grad: + tensor = tensor.detach() + if tensor.is_cuda: + tensor = tensor.cpu() + return tensor.numpy() + + +def load_model( + expdir, + model, + optimizer, + name='model', + postfix='', + device='cpu'): + if postfix == '': + postfix = '_' + postfix + path = os.path.join(expdir, name+postfix) + path_pt = traverse_dir(expdir, ['pt'], is_ext=False) + global_step = 0 + if len(path_pt) > 0: + steps = [s[len(path):] for s in path_pt] + maxstep = max([int(s) if s.isdigit() else 0 for s in steps]) + if maxstep >= 0: + path_pt = path+str(maxstep)+'.pt' + else: + path_pt = path+'best.pt' + print(' [*] restoring model from', path_pt) + ckpt = torch.load(path_pt, map_location=torch.device(device)) + global_step = ckpt['global_step'] + model.load_state_dict(ckpt['model'], strict=False) + if ckpt.get("optimizer") is not None: + optimizer.load_state_dict(ckpt['optimizer']) + return global_step, model, optimizer diff --git a/AIMeiSheng/diffuse_fang/diffusion/onnx_export.py b/AIMeiSheng/diffuse_fang/diffusion/onnx_export.py new file mode 100644 index 0000000..6a4ea22 --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/onnx_export.py @@ -0,0 +1,235 @@ +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import yaml +from diffusion_onnx import GaussianDiffusion + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def load_model_vocoder( + model_path, + device='cpu'): + config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') + with open(config_file, "r") as config: + args = yaml.safe_load(config) + args = DotDict(args) + + # load model + model = Unit2Mel( + args.data.encoder_out_channels, + args.model.n_spk, + args.model.use_pitch_aug, + 128, + args.model.n_layers, + args.model.n_chans, + args.model.n_hidden, + args.model.timesteps, + args.model.k_step_max) + + print(' [Loading] ' + model_path) + ckpt = torch.load(model_path, map_location=torch.device(device)) + model.to(device) + model.load_state_dict(ckpt['model']) + model.eval() + return model, args + + +class Unit2Mel(nn.Module): + def __init__( + self, + input_channel, + n_spk, + use_pitch_aug=False, + out_dims=128, + n_layers=20, + n_chans=384, + n_hidden=256, + timesteps=1000, + k_step_max=1000): + super().__init__() + + self.unit_embed = nn.Linear(input_channel, n_hidden) + self.f0_embed = nn.Linear(1, n_hidden) + self.volume_embed = nn.Linear(1, n_hidden) + if use_pitch_aug: + self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False) + else: + self.aug_shift_embed = None + self.n_spk = n_spk + if n_spk is not None and n_spk > 1: + self.spk_embed = nn.Embedding(n_spk, n_hidden) + + self.timesteps = timesteps if timesteps is not None else 1000 + self.k_step_max = k_step_max if k_step_max is not None and k_step_max>0 and k_step_max 1: # [N, S] * [S, B, 1, H] + g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1] + g = g * self.speaker_map # [N, S, B, 1, H] + g = torch.sum(g, dim=1) # [N, 1, B, 1, H] + g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N] + x = x.transpose(1, 2) + g + return x + else: + return x.transpose(1, 2) + + + def init_spkembed(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None, + gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True): + + ''' + input: + B x n_frames x n_unit + return: + dict of B x n_frames x feat + ''' + x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume) + if self.n_spk is not None and self.n_spk > 1: + if spk_mix_dict is not None: + spk_embed_mix = torch.zeros((1,1,self.hidden_size)) + for k, v in spk_mix_dict.items(): + spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) + spk_embeddd = self.spk_embed(spk_id_torch) + self.speaker_map[k] = spk_embeddd + spk_embed_mix = spk_embed_mix + v * spk_embeddd + x = x + spk_embed_mix + else: + x = x + self.spk_embed(spk_id - 1) + self.speaker_map = self.speaker_map.unsqueeze(0) + self.speaker_map = self.speaker_map.detach() + return x.transpose(1, 2) + + def OnnxExport(self, project_name=None, init_noise=None, export_encoder=True, export_denoise=True, export_pred=True, export_after=True): + hubert_hidden_size = 768 + n_frames = 100 + hubert = torch.randn((1, n_frames, hubert_hidden_size)) + mel2ph = torch.arange(end=n_frames).unsqueeze(0).long() + f0 = torch.randn((1, n_frames)) + volume = torch.randn((1, n_frames)) + spk_mix = [] + spks = {} + if self.n_spk is not None and self.n_spk > 1: + for i in range(self.n_spk): + spk_mix.append(1.0/float(self.n_spk)) + spks.update({i:1.0/float(self.n_spk)}) + spk_mix = torch.tensor(spk_mix) + spk_mix = spk_mix.repeat(n_frames, 1) + self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) + self.forward(hubert, mel2ph, f0, volume, spk_mix) + if export_encoder: + torch.onnx.export( + self, + (hubert, mel2ph, f0, volume, spk_mix), + f"{project_name}_encoder.onnx", + input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"], + output_names=["mel_pred"], + dynamic_axes={ + "hubert": [1], + "f0": [1], + "volume": [1], + "mel2ph": [1], + "spk_mix": [0], + }, + opset_version=16 + ) + + self.decoder.OnnxExport(project_name, init_noise=init_noise, export_denoise=export_denoise, export_pred=export_pred, export_after=export_after) + + def ExportOnnx(self, project_name=None): + hubert_hidden_size = 768 + n_frames = 100 + hubert = torch.randn((1, n_frames, hubert_hidden_size)) + mel2ph = torch.arange(end=n_frames).unsqueeze(0).long() + f0 = torch.randn((1, n_frames)) + volume = torch.randn((1, n_frames)) + spk_mix = [] + spks = {} + if self.n_spk is not None and self.n_spk > 1: + for i in range(self.n_spk): + spk_mix.append(1.0/float(self.n_spk)) + spks.update({i:1.0/float(self.n_spk)}) + spk_mix = torch.tensor(spk_mix) + self.orgforward(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) + self.forward(hubert, mel2ph, f0, volume, spk_mix) + + torch.onnx.export( + self, + (hubert, mel2ph, f0, volume, spk_mix), + f"{project_name}_encoder.onnx", + input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"], + output_names=["mel_pred"], + dynamic_axes={ + "hubert": [1], + "f0": [1], + "volume": [1], + "mel2ph": [1] + }, + opset_version=16 + ) + + condition = torch.randn(1,self.decoder.n_hidden,n_frames) + noise = torch.randn((1, 1, self.decoder.mel_bins, condition.shape[2]), dtype=torch.float32) + pndm_speedup = torch.LongTensor([100]) + K_steps = torch.LongTensor([1000]) + self.decoder = torch.jit.script(self.decoder) + self.decoder(condition, noise, pndm_speedup, K_steps) + + torch.onnx.export( + self.decoder, + (condition, noise, pndm_speedup, K_steps), + f"{project_name}_diffusion.onnx", + input_names=["condition", "noise", "pndm_speedup", "K_steps"], + output_names=["mel"], + dynamic_axes={ + "condition": [2], + "noise": [3], + }, + opset_version=16 + ) + + +if __name__ == "__main__": + project_name = "dddsp" + model_path = f'{project_name}/model_500000.pt' + + model, _ = load_model_vocoder(model_path) + + # 分开Diffusion导出(需要使用MoeSS/MoeVoiceStudio或者自己编写Pndm/Dpm采样) + model.OnnxExport(project_name, export_encoder=True, export_denoise=True, export_pred=True, export_after=True) + + # 合并Diffusion导出(Encoder和Diffusion分开,直接将Encoder的结果和初始噪声输入Diffusion即可) + # model.ExportOnnx(project_name) + diff --git a/AIMeiSheng/diffuse_fang/diffusion/solver.py b/AIMeiSheng/diffuse_fang/diffusion/solver.py new file mode 100644 index 0000000..52657cc --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/solver.py @@ -0,0 +1,200 @@ +import time + +import librosa +import numpy as np +import torch +from torch import autocast +from torch.cuda.amp import GradScaler + +from diffusion.logger import utils +from diffusion.logger.saver import Saver + + +def test(args, model, vocoder, loader_test, saver): + print(' [*] testing...') + model.eval() + + # losses + test_loss = 0. + + # intialization + num_batches = len(loader_test) + rtf_all = [] + + # run + with torch.no_grad(): + for bidx, data in enumerate(loader_test): + fn = data['name'][0].split("/")[-1] + speaker = data['name'][0].split("/")[-2] + print('--------') + print('{}/{} - {}'.format(bidx, num_batches, fn)) + + # unpack data + for k in data.keys(): + if not k.startswith('name'): + data[k] = data[k].to(args.device) + print('>>', data['name'][0]) + + # forward + st_time = time.time() + mel = model( + data['units'], + data['f0'], + data['volume'], + data['spk_id'], + gt_spec=None if model.k_step_max == model.timesteps else data['mel'], + infer=True, + infer_speedup=args.infer.speedup, + method=args.infer.method, + k_step=model.k_step_max + ) + signal = vocoder.infer(mel, data['f0']) + ed_time = time.time() + + # RTF + run_time = ed_time - st_time + song_time = signal.shape[-1] / args.data.sampling_rate + rtf = run_time / song_time + print('RTF: {} | {} / {}'.format(rtf, run_time, song_time)) + rtf_all.append(rtf) + + # loss + for i in range(args.train.batch_size): + loss = model( + data['units'], + data['f0'], + data['volume'], + data['spk_id'], + gt_spec=data['mel'], + infer=False, + k_step=model.k_step_max) + test_loss += loss.item() + + # log mel + saver.log_spec(f"{speaker}_{fn}.wav", data['mel'], mel) + + # log audi + path_audio = data['name_ext'][0] + audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio) + audio = torch.from_numpy(audio).unsqueeze(0).to(signal) + saver.log_audio({f"{speaker}_{fn}_gt.wav": audio,f"{speaker}_{fn}_pred.wav": signal}) + # report + test_loss /= args.train.batch_size + test_loss /= num_batches + + # check + print(' [test_loss] test_loss:', test_loss) + print(' Real Time Factor', np.mean(rtf_all)) + return test_loss + + +def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test): + # saver + saver = Saver(args, initial_global_step=initial_global_step) + + # model size + params_count = utils.get_network_paras_amount({'model': model}) + saver.log_info('--- model size ---') + saver.log_info(params_count) + + # run + num_batches = len(loader_train) + model.train() + saver.log_info('======= start training =======') + scaler = GradScaler() + if args.train.amp_dtype == 'fp32': + dtype = torch.float32 + elif args.train.amp_dtype == 'fp16': + dtype = torch.float16 + elif args.train.amp_dtype == 'bf16': + dtype = torch.bfloat16 + else: + raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype) + saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step") + for epoch in range(args.train.epochs): + for batch_idx, data in enumerate(loader_train): + saver.global_step_increment() + optimizer.zero_grad() + + # unpack data + for k in data.keys(): + if not k.startswith('name'): + data[k] = data[k].to(args.device) + + # forward + if dtype == torch.float32: + loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'], + aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False, k_step=model.k_step_max) + else: + with autocast(device_type=args.device, dtype=dtype): + loss = model(data['units'], data['f0'], data['volume'], data['spk_id'], + aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False, k_step=model.k_step_max) + + # handle nan loss + if torch.isnan(loss): + raise ValueError(' [x] nan loss ') + else: + # backpropagate + if dtype == torch.float32: + loss.backward() + optimizer.step() + else: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + scheduler.step() + + # log loss + if saver.global_step % args.train.interval_log == 0: + current_lr = optimizer.param_groups[0]['lr'] + saver.log_info( + 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format( + epoch, + batch_idx, + num_batches, + args.env.expdir, + args.train.interval_log/saver.get_interval_time(), + current_lr, + loss.item(), + saver.get_total_time(), + saver.global_step + ) + ) + + saver.log_value({ + 'train/loss': loss.item() + }) + + saver.log_value({ + 'train/lr': current_lr + }) + + # validation + if saver.global_step % args.train.interval_val == 0: + optimizer_save = optimizer if args.train.save_opt else None + + # save latest + saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}') + last_val_step = saver.global_step - args.train.interval_val + if last_val_step % args.train.interval_force_save != 0: + saver.delete_model(postfix=f'{last_val_step}') + + # run testing set + test_loss = test(args, model, vocoder, loader_test, saver) + + # log loss + saver.log_info( + ' --- --- \nloss: {:.3f}. '.format( + test_loss, + ) + ) + + saver.log_value({ + 'validation/loss': test_loss + }) + + model.train() + + diff --git a/AIMeiSheng/diffuse_fang/diffusion/uni_pc.py b/AIMeiSheng/diffuse_fang/diffusion/uni_pc.py new file mode 100644 index 0000000..72d8f51 --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/uni_pc.py @@ -0,0 +1,733 @@ +import math + +import torch + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + def log_alpha_fn(s): + return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + def t_fn(log_alpha_t): + return torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2.0 * (1.0 + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - alpha_t * output) / sigma_t + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return alpha_t * output + sigma_t * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -sigma_t * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * sigma_t * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class UniPC: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="data_prediction", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + variant='bh1' + ): + """Construct a UniPC. + + We support both data_prediction and noise_prediction. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["data_prediction", "noise_prediction"] + + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + self.variant = variant + self.predict_x0 = algorithm_type == "data_prediction" + + def dynamic_thresholding_fn(self, x0, t=None): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = steps + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs): + if len(t.shape) == 0: + t = t.view(-1) + if 'bh' in self.variant: + return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + else: + assert self.variant == 'vary_coeff' + return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + + def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): + #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_t = ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + K = len(rks) + # build C matrix + C = [] + + col = torch.ones_like(rks) + for k in range(1, K + 1): + C.append(col) + col = col * rks / (k + 1) + C = torch.stack(C, dim=1) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + C_inv_p = torch.linalg.inv(C[:-1, :-1]) + A_p = C_inv_p + + if use_corrector: + #print('using corrector') + C_inv = torch.linalg.inv(C) + A_c = C_inv + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_ks = [] + factorial_k = 1 + h_phi_k = h_phi_1 + for k in range(1, K + 2): + h_phi_ks.append(h_phi_k) + h_phi_k = h_phi_k / hh - 1 / factorial_k + factorial_k *= (k + 1) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + else: + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + x_t_ = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * h_phi_1) * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + return x_t, model_t + + def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): + #print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.variant == 'bh1': + B_h = hh + elif self.variant == 'bh2': + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= (i + 1) + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.cat(b) + + # now predictor + use_predictor = len(D1s) > 0 and x_t is None + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + if x_t is None: + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + + if use_corrector: + #print('using corrector') + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], device=b.device) + else: + rhos_c = torch.linalg.solve(R, b) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = ( + torch.exp(log_alpha_t - log_alpha_prev_0) * x + - sigma_t * h_phi_1 * model_prev_0 + ) + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + return x_t, model_t + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by UniPC, given the initial `x` at time `t_start`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + # Init the first `order` values by lower order multistep UniPC. + for step in range(1, order): + t = timesteps[step] + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step, use_corrector=True) + if model_x is None: + model_x = self.model_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(model_x) + + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + t = timesteps[step] + if lower_order_final: + step_order = min(order, steps + 1 - step) + else: + step_order = order + if step == steps: + #print('do not run corrector at the last step') + use_corrector = False + else: + use_corrector = True + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step_order, use_corrector=use_corrector) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + if model_x is None: + model_x = self.model_fn(x, t) + model_prev_list[-1] = model_x + else: + raise ValueError("Got wrong method {}".format(method)) + + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/AIMeiSheng/diffuse_fang/diffusion/unit2mel.py b/AIMeiSheng/diffuse_fang/diffusion/unit2mel.py new file mode 100644 index 0000000..5087f2a --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/unit2mel.py @@ -0,0 +1,167 @@ +import os + +import numpy as np +import torch +import torch.nn as nn +import yaml + +from .diffusion import GaussianDiffusion +from .vocoder import Vocoder +from .wavenet import WaveNet + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def load_model_vocoder( + model_path, + device='cpu', + config_path = None + ): + if config_path is None: + config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') + else: + config_file = config_path + + with open(config_file, "r") as config: + args = yaml.safe_load(config) + args = DotDict(args) + + # load vocoder + vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device) + + # load model + model = Unit2Mel( + args.data.encoder_out_channels, + args.model.n_spk, + args.model.use_pitch_aug, + vocoder.dimension, + args.model.n_layers, + args.model.n_chans, + args.model.n_hidden, + args.model.timesteps, + args.model.k_step_max + ) + + print(' [Loading] ' + model_path) + ckpt = torch.load(model_path, map_location=torch.device(device)) + model.to(device) + model.load_state_dict(ckpt['model']) + model.eval() + print(f'Loaded diffusion model, sampler is {args.infer.method}, speedup: {args.infer.speedup} ') + return model, vocoder, args + + +class Unit2Mel(nn.Module): + def __init__( + self, + input_channel, + n_spk, + use_pitch_aug=False, + out_dims=128, + n_layers=20, + n_chans=384, + n_hidden=256, + timesteps=1000, + k_step_max=1000 + ): + super().__init__() + self.unit_embed = nn.Linear(input_channel, n_hidden) + self.f0_embed = nn.Linear(1, n_hidden) + self.volume_embed = nn.Linear(1, n_hidden) + if use_pitch_aug: + self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False) + else: + self.aug_shift_embed = None + self.n_spk = n_spk + if n_spk is not None and n_spk > 1: + self.spk_embed = nn.Embedding(n_spk, n_hidden) + + self.timesteps = timesteps if timesteps is not None else 1000 + self.k_step_max = k_step_max if k_step_max is not None and k_step_max>0 and k_step_max 1: + if spk_mix_dict is not None: + spk_embed_mix = torch.zeros((1,1,self.hidden_size)) + for k, v in spk_mix_dict.items(): + spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) + spk_embeddd = self.spk_embed(spk_id_torch) + self.speaker_map[k] = spk_embeddd + spk_embed_mix = spk_embed_mix + v * spk_embeddd + x = x + spk_embed_mix + else: + x = x + self.spk_embed(spk_id - 1) + self.speaker_map = self.speaker_map.unsqueeze(0) + self.speaker_map = self.speaker_map.detach() + return x.transpose(1, 2) + + def init_spkmix(self, n_spk): + self.speaker_map = torch.zeros((n_spk,1,1,self.n_hidden)) + hubert_hidden_size = self.input_channel + n_frames = 10 + hubert = torch.randn((1, n_frames, hubert_hidden_size)) + f0 = torch.randn((1, n_frames)) + volume = torch.randn((1, n_frames)) + spks = {} + for i in range(n_spk): + spks.update({i:1.0/float(self.n_spk)}) + self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) + + def forward(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None, + gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True): + + ''' + input: + B x n_frames x n_unit + return: + dict of B x n_frames x feat + ''' + + if not self.training and gt_spec is not None and k_step>self.k_step_max: + raise Exception("The shallow diffusion k_step is greater than the maximum diffusion k_step(k_step_max)!") + + if not self.training and gt_spec is None and self.k_step_max!=self.timesteps: + raise Exception("This model can only be used for shallow diffusion and can not infer alone!") + + x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume) + if self.n_spk is not None and self.n_spk > 1: + if spk_mix_dict is not None: + for k, v in spk_mix_dict.items(): + spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) + x = x + v * self.spk_embed(spk_id_torch) + else: + if spk_id.shape[1] > 1: + g = spk_id.reshape((spk_id.shape[0], spk_id.shape[1], 1, 1, 1)) # [N, S, B, 1, 1] + g = g * self.speaker_map # [N, S, B, 1, H] + g = torch.sum(g, dim=1) # [N, 1, B, 1, H] + g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N] + x = x + g + else: + x = x + self.spk_embed(spk_id) + if self.aug_shift_embed is not None and aug_shift is not None: + x = x + self.aug_shift_embed(aug_shift / 5) + x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm) + + return x + diff --git a/AIMeiSheng/diffuse_fang/diffusion/vocoder.py b/AIMeiSheng/diffuse_fang/diffusion/vocoder.py new file mode 100644 index 0000000..ec9c80e --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/vocoder.py @@ -0,0 +1,95 @@ +import torch +from torchaudio.transforms import Resample + +from vdecoder.nsf_hifigan.models import load_config, load_model +from vdecoder.nsf_hifigan.nvSTFT import STFT + + +class Vocoder: + def __init__(self, vocoder_type, vocoder_ckpt, device = None): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + + if vocoder_type == 'nsf-hifigan': + self.vocoder = NsfHifiGAN(vocoder_ckpt, device = device) + elif vocoder_type == 'nsf-hifigan-log10': + self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device = device) + else: + raise ValueError(f" [x] Unknown vocoder: {vocoder_type}") + + self.resample_kernel = {} + self.vocoder_sample_rate = self.vocoder.sample_rate() + self.vocoder_hop_size = self.vocoder.hop_size() + self.dimension = self.vocoder.dimension() + + def extract(self, audio, sample_rate, keyshift=0): + + # resample + if sample_rate == self.vocoder_sample_rate: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + + # extract + mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins + return mel + + def infer(self, mel, f0): + f0 = f0[:,:mel.size(1),0] # B, n_frames + audio = self.vocoder(mel, f0) + return audio + + +class NsfHifiGAN(torch.nn.Module): + def __init__(self, model_path, device=None): + super().__init__() + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + self.model_path = model_path + self.model = None + self.h = load_config(model_path) + self.stft = STFT( + self.h.sampling_rate, + self.h.num_mels, + self.h.n_fft, + self.h.win_size, + self.h.hop_size, + self.h.fmin, + self.h.fmax) + + def sample_rate(self): + return self.h.sampling_rate + + def hop_size(self): + return self.h.hop_size + + def dimension(self): + return self.h.num_mels + + def extract(self, audio, keyshift=0): + mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins + return mel + + def forward(self, mel, f0): + if self.model is None: + print('| Load HifiGAN: ', self.model_path) + self.model, self.h = load_model(self.model_path, device=self.device) + with torch.no_grad(): + c = mel.transpose(1, 2) + audio = self.model(c, f0) + return audio + +class NsfHifiGANLog10(NsfHifiGAN): + def forward(self, mel, f0): + if self.model is None: + print('| Load HifiGAN: ', self.model_path) + self.model, self.h = load_model(self.model_path, device=self.device) + with torch.no_grad(): + c = 0.434294 * mel.transpose(1, 2) + audio = self.model(c, f0) + return audio \ No newline at end of file diff --git a/AIMeiSheng/diffuse_fang/diffusion/wavenet.py b/AIMeiSheng/diffuse_fang/diffusion/wavenet.py new file mode 100644 index 0000000..30404d3 --- /dev/null +++ b/AIMeiSheng/diffuse_fang/diffusion/wavenet.py @@ -0,0 +1,110 @@ +import math +from math import sqrt + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Mish + + +class Conv1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + nn.init.kaiming_normal_(self.weight) + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class ResidualBlock(nn.Module): + def __init__(self, encoder_hidden, residual_channels, dilation): + super().__init__() + self.residual_channels = residual_channels + self.dilated_conv = nn.Conv1d( + residual_channels, + 2 * residual_channels, + kernel_size=3, + padding=dilation, + dilation=dilation + ) + self.diffusion_projection = nn.Linear(residual_channels, residual_channels) + self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1) + self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, conditioner, diffusion_step): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + conditioner = self.conditioner_projection(conditioner) + y = x + diffusion_step + + y = self.dilated_conv(y) + conditioner + + # Using torch.split instead of torch.chunk to avoid using onnx::Slice + gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + + # Using torch.split instead of torch.chunk to avoid using onnx::Slice + residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + return (x + residual) / math.sqrt(2.0), skip + + +class WaveNet(nn.Module): + def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256): + super().__init__() + self.input_projection = Conv1d(in_dims, n_chans, 1) + self.diffusion_embedding = SinusoidalPosEmb(n_chans) + self.mlp = nn.Sequential( + nn.Linear(n_chans, n_chans * 4), + Mish(), + nn.Linear(n_chans * 4, n_chans) + ) + self.residual_layers = nn.ModuleList([ + ResidualBlock( + encoder_hidden=n_hidden, + residual_channels=n_chans, + dilation=1 + ) + for i in range(n_layers) + ]) + self.skip_projection = Conv1d(n_chans, n_chans, 1) + self.output_projection = Conv1d(n_chans, in_dims, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, diffusion_step, cond): + """ + :param spec: [B, 1, M, T] + :param diffusion_step: [B, 1] + :param cond: [B, M, T] + :return: + """ + x = spec.squeeze(1) + #x = x.half() #fang add + x = self.input_projection(x) # [B, residual_channel, T] + + x = F.relu(x) + diffusion_step = self.diffusion_embedding(diffusion_step) + #diffusion_step = diffusion_step.half() #fangadd + diffusion_step = self.mlp(diffusion_step) + skip = [] + for layer in self.residual_layers: + x, skip_connection = layer(x, cond, diffusion_step) + skip.append(skip_connection) + + x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) # [B, mel_bins, T] + return x[:, None, :, :] diff --git a/AIMeiSheng/lib/infer_pack/__pycache__/attentions_in_dec.cpython-38.pyc b/AIMeiSheng/lib/infer_pack/__pycache__/attentions_in_dec.cpython-38.pyc deleted file mode 100644 index e398b08..0000000 Binary files a/AIMeiSheng/lib/infer_pack/__pycache__/attentions_in_dec.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/lib/infer_pack/__pycache__/commons.cpython-38.pyc b/AIMeiSheng/lib/infer_pack/__pycache__/commons.cpython-38.pyc deleted file mode 100644 index f4bfdd8..0000000 Binary files a/AIMeiSheng/lib/infer_pack/__pycache__/commons.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/lib/infer_pack/__pycache__/models_embed_in_dec_diff_fi.cpython-38.pyc b/AIMeiSheng/lib/infer_pack/__pycache__/models_embed_in_dec_diff_fi.cpython-38.pyc deleted file mode 100644 index 98fa5b1..0000000 Binary files a/AIMeiSheng/lib/infer_pack/__pycache__/models_embed_in_dec_diff_fi.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/lib/infer_pack/__pycache__/modules.cpython-38.pyc b/AIMeiSheng/lib/infer_pack/__pycache__/modules.cpython-38.pyc deleted file mode 100644 index 42763e2..0000000 Binary files a/AIMeiSheng/lib/infer_pack/__pycache__/modules.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/lib/infer_pack/__pycache__/transforms.cpython-38.pyc b/AIMeiSheng/lib/infer_pack/__pycache__/transforms.cpython-38.pyc deleted file mode 100644 index 088ec9e..0000000 Binary files a/AIMeiSheng/lib/infer_pack/__pycache__/transforms.cpython-38.pyc and /dev/null differ diff --git a/AIMeiSheng/lib/infer_pack/models_embed_in_dec_diff_control_enc.py b/AIMeiSheng/lib/infer_pack/models_embed_in_dec_diff_control_enc.py new file mode 100644 index 0000000..afd1f8a --- /dev/null +++ b/AIMeiSheng/lib/infer_pack/models_embed_in_dec_diff_control_enc.py @@ -0,0 +1,1275 @@ +import math, pdb, os +from time import time as ttime +import torch +from torch import nn +from torch.nn import functional as F +from lib.infer_pack import modules +from lib.infer_pack import attentions_in_dec as attentions +from lib.infer_pack import commons +from lib.infer_pack.commons import init_weights, get_padding +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from lib.infer_pack.commons import init_weights +import numpy as np +from lib.infer_pack import commons +from thop import profile +from diffuse_fang.diffUse_wraper import diff_decoder,ddpm_para +ddpm_dp = ddpm_para() + +class TextEncoder256(nn.Module): + def __init__( + self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + f0=True, + ): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.emb_phone = nn.Linear(256, hidden_channels) + self.lrelu = nn.LeakyReLU(0.1, inplace=True) + if f0 == True: + self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256 + self.encoder = attentions.Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, phone, pitch, lengths): + if pitch == None: + x = self.emb_phone(phone) + else: + x = self.emb_phone(phone) + self.emb_pitch(pitch) + x = x * math.sqrt(self.hidden_channels) # [b, t, h] + x = self.lrelu(x) + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return m, logs, x_mask + + +class TextEncoder768(nn.Module): + def __init__( + self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + f0=True, + ): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.emb_phone = nn.Linear(768, hidden_channels) + self.lrelu = nn.LeakyReLU(0.1, inplace=True) + if f0 == True: + self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256 + self.encoder = attentions.Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + #self.emb_g = nn.Linear(256, hidden_channels) + + def forward(self, phone, pitch, lengths,g):#fang add + if pitch == None: + x = self.emb_phone(phone) + else: + x = self.emb_phone(phone) + self.emb_pitch(pitch) #+ self.emb_g(g) + #print("@@@x:",x.shape) + x = x * math.sqrt(self.hidden_channels) # [b, t, h] + x = self.lrelu(x) + x = torch.transpose(x, 1, -1) # [b, h, t] + #print("@@@x1:",x.shape) + x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to( + x.dtype + ) + #x = self.encoder(x * x_mask, x_mask,g) + x = self.encoder(x * x_mask, x_mask,g)#fang add + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return m, logs, x_mask,x + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + def remove_weight_norm(self): + for i in range(self.n_flows): + self.flows[i * 2].remove_weight_norm() + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1)#均值和方差 fang + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask ##随机采样 fang + return z, m, logs, x_mask + + def remove_weight_norm(self): + self.enc.remove_weight_norm() + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class SineGen(torch.nn.Module): + """Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__( + self, + samp_rate, + harmonic_num=0, + sine_amp=0.1, + noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False, + ): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv + + def forward(self, f0, upp): + """sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( + idx + 2 + ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化 + rand_ini = torch.rand( + f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化 + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), + scale_factor=upp, + mode="linear", + align_corners=True, + ).transpose(2, 1) + rad_values = F.interpolate( + rad_values.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose( + 2, 1 + ) ####### + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi + ) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate( + uv.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__( + self, + sampling_rate, + harmonic_num=0, + sine_amp=0.1, + add_noise_std=0.003, + voiced_threshod=0, + is_half=True, + ): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + self.is_half = is_half + # to produce sine waveforms + self.l_sin_gen = SineGen( + sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod + ) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x, upp=None): + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + if self.is_half: + sine_wavs = sine_wavs.half() + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + return sine_merge, None, None # noise, uv + + +class GeneratorNSF(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + sr, + is_half=False, + ): + super(GeneratorNSF, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates)) + self.m_source = SourceModuleHnNSF( + sampling_rate=sr, harmonic_num=0, is_half=is_half + ) + self.noise_convs = nn.ModuleList() + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + self.ups_g = nn.ModuleList()# fang add + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + c_cur = upsample_initial_channel // (2 ** (i + 1)) + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + self.ups_g.append( + nn.Conv1d(upsample_initial_channel,upsample_initial_channel // (2 ** (i + 1) ), 1) + #F.interpolate(input, scale_factor=2, mode='nearest') + )# fang add + if i + 1 < len(upsample_rates): + stride_f0 = np.prod(upsample_rates[i + 1 :]) + self.noise_convs.append( + Conv1d( + 1, + c_cur, + kernel_size=stride_f0 * 2, + stride=stride_f0, + padding=stride_f0 // 2, + ) + ) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + self.upp = np.prod(upsample_rates) + + def forward(self, x, f0, g=None): + har_source, noi_source, uv = self.m_source(f0, self.upp) + har_source = har_source.transpose(1, 2) + x = self.conv_pre(x) + if g is not None: + #x = x + self.cond(g) ##org + tmp_g = self.cond(g) ##fang add + x = x + tmp_g ##fang add + #print('###@@@@##x:',x.shape ) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + x = x + x_source + xg = self.ups_g[i](tmp_g) #fang add + x = x + xg #fang add + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + #print('@@@@##x:',x.shape) + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +sr2sr = { + "32k": 32000, + "40k": 40000, + "48k": 48000, + "24k": 24000, +} + + +class SynthesizerTrnMs256NSFsid(nn.Module): + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + sr, + **kwargs + ): + super().__init__() + if type(sr) == type("strr"): + sr = sr2sr[sr] + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + # self.hop_length = hop_length# + self.spk_embed_dim = spk_embed_dim + self.enc_p = TextEncoder256( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + ) + self.dec = GeneratorNSF( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + sr=sr, + is_half=kwargs["is_half"], + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels + ) + self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) + print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim) + + def remove_weight_norm(self): + self.dec.remove_weight_norm() + self.flow.remove_weight_norm() + self.enc_q.remove_weight_norm() + + def forward( + self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds + ): # 这里ds是id,[bs,1] + # print(1,pitch.shape)#[bs,t] + g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 + #print("@@@pitch.shape: ",pitch.shape) + #g = ds.unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + z_slice, ids_slice = commons.rand_slice_segments( + z, y_lengths, self.segment_size + ) #按照self.segment_size这个长度,进行随机切割z,长度固定,开始位置不同存在ids_slice中,z_slice是切割的结果, fang + # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length) + pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size) + # print(-2,pitchf.shape,z_slice.shape) + o = self.dec(z_slice, pitchf, g=g) + return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None): + g = self.emb_g(sid).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + if rate: + head = int(z_p.shape[2] * rate) + z_p = z_p[:, :, -head:] + x_mask = x_mask[:, :, -head:] + nsff0 = nsff0[:, -head:] + z = self.flow(z_p, x_mask, g=g, reverse=True) + print('z shape: ',z.shape) + print('x_mask shape: ',x_mask.shape) + z_x_mask = z * x_mask + print('z_x_mask shape: ',z_x_mask.shape) + print('nsff0 shape:p', nsff0.shape) + print('g shape: ',g.shape) + o = self.dec(z * x_mask, nsff0, g=g) + + self.get_floats() + return o, x_mask, (z, z_p, m_p, logs_p) + + def get_floats(self,): + T = 21.4 #郭宇_但愿人长久_40k.wav + z = torch.randn(1,192 ,2740)# 2s data(同时用2s数据验证,整数倍就对了,防止干扰) + x_mask = torch.randn(1,1 ,2740) + g = torch.randn(1,256 ,1) + + inputs_bfcc = z #z * x_mask + nsff0 = torch.randn(1, 2740) + devices = 'cuda' #'cpu' + self.dec = self.dec.to(devices).half() + inputs_bfcc , nsff0, g = inputs_bfcc.to(devices).half(), nsff0.to(devices).half(), g.to(devices).half() + flops, params = profile(self.dec, (inputs_bfcc, nsff0, g)) + print(f'@@@hifi-gan nsf decflops: {flops/(T*pow(10,9))} GFLOPS, params: { params/pow(10,6)} M') + return 0 + +class SynthesizerTrnMs768NSFsid(nn.Module): + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + sr, + **kwargs + ): + super().__init__() + if type(sr) == type("strr"): + sr = sr2sr[sr] + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + # self.hop_length = hop_length# + self.spk_embed_dim = spk_embed_dim + self.enc_p = TextEncoder768( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + ) + self.dec = GeneratorNSF( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + sr=sr, + is_half=kwargs["is_half"], + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels + ) + #for p in self.flow.parameters(): + # p.requires_grad=False + #for p in self.enc_p.parameters(): + # p.requires_grad=False + + self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) + print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim) + + self.diff_decoder = diff_decoder + #self.diff_cond_g = nn.Conv1d(256,192, 1) + self.diff_cond_gx = self.zero_module(self.conv_nd(1, 256, 192, 3, padding=1)) + self.diff_cond_out = self.zero_module(self.conv_nd(1, 192, 192, 3, padding=1)) + self.lzp = 0.1 + + def zero_module(self,module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + def conv_nd(self, dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + def remove_weight_norm(self): + self.dec.remove_weight_norm() + self.flow.remove_weight_norm() + self.enc_q.remove_weight_norm() + + def forward( + self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds + ): # 这里ds是id,[bs,1] + # print(1,pitch.shape)#[bs,t] + #g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 + #print("@@@@@fang@@@@@") + g = ds.unsqueeze(-1) + #print("g:",g.size()) + #print("phone_lengths: ",phone_lengths.size()) + #print("pitch: ",pitch.size()) + #m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + m_p, logs_p, x_mask, x_embed = self.enc_p(phone, pitch, phone_lengths,g)#fang add + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)#self.enc_q = PosteriorEncoder ##这里面预测出了随机采样的隐变量z,m_q是均值,logs_q是方差,y_mask是mask的数据 fangi + + z_p = self.flow(z, y_mask, g=g)# z是y_msk的输入 + z_p_sample = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * y_mask + zx = self.flow(z_p_sample, y_mask, g=g, reverse=True) + #print("@@@@@g:",g.shape) + g_z_p = self.diff_cond_gx(g) + #print("@@@@@g_z_p:",g_z_p.shape) + z_res = z - zx + + #print('#######x_embed:',x_embed.shape) + #print('#######z_p_sample:',z_p_sample.shape) + #z_p1 = z_p_sample + g_z_p + z_p1 = x_embed + g_z_p + ###diff st + z_p_diff = z_p1.transpose(1,2) ##b,frames,feat + z_diff = z_res.transpose(1,2) ##b,frames,feat + + diff_loss,_ = self.diff_decoder(z_p_diff, gt_spec=z_diff, infer=False, infer_speedup=ddpm_dp.infer_speedup, method=ddpm_dp.method, use_tqdm=ddpm_dp.use_tqdm) + + #self.diff_decoder = self.diff_decoder.float() + #print("@@@z: ",z.shape) + #b = z_p_diff.shape[0] + t = 200#torch.randint(0, 1000, (b,), device=g.device).long() + z_diff = zx.transpose(1,2) + z_x_diff = self.diff_decoder(z_p_diff, gt_spec=z_diff*self.lzp, infer=True, infer_speedup=ddpm_dp.infer_speedup, method=ddpm_dp.method, k_step=t, use_tqdm=False) + #print("@@@z_x: ",z_x.shape) + z1 = z_x_diff.transpose(1,2) + z1 = self.diff_cond_out(z1) + z_in = (zx + z1) + #z_p = z_p_rec.transpose(1,2) + ##diff en + ##oneflow + #z_p = self.flow(z, y_mask, g=g) + + z_slice, ids_slice = commons.rand_slice_segments( + z_in, y_lengths, self.segment_size + ) + # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length) + pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size) + # print(-2,pitchf.shape,z_slice.shape) + o = self.dec(z_slice, pitchf, g=g) + return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q),diff_loss + + def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None): + #g = self.emb_g(sid).unsqueeze(-1) + g = sid.unsqueeze(-1).unsqueeze(0) + #m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) #org + m_p, logs_p, x_mask, x_embed = self.enc_p(phone, pitch, phone_lengths,g) #fang add + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + if rate: + head = int(z_p.shape[2] * rate) + z_p = z_p[:, :, -head:] + x_mask = x_mask[:, :, -head:] + nsff0 = nsff0[:, -head:] + z = self.flow(z_p, x_mask, g=g, reverse=True) + + g_z_p = self.diff_cond_gx(g) + #z_p1 = z_p + g_z_p + z_p1 = x_embed + g_z_p + #if is_half: + #self.diff_decoder = self.diff_decoder.float() + z_p_diff = z_p1.transpose(1,2).float() ##b,frames,feat + z_diff = z.transpose(1,2) ##b,frames,feat + #print("@@z_p_diff", z_p_diff[0,0,:]) + self.diff_decoder = self.diff_decoder.float() + z_x = self.diff_decoder(z_p_diff, gt_spec=z_diff*self.lzp, infer=True, infer_speedup=ddpm_dp.infer_speedup, method=ddpm_dp.method, k_step=200, use_tqdm=ddpm_dp.use_tqdm) + #print("@@z_x", z_x[0,0,:]) + z1 = z_x.transpose(1,2).half() + z_res = self.diff_cond_out(z1) + z = z + z_res + o = self.dec(z * x_mask, nsff0, g=g) + #self.get_floats() + return o, x_mask, (z, z_p, m_p, logs_p) + + +class SynthesizerTrnMs256NSFsid_nono(nn.Module): + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + sr=None, + **kwargs + ): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + # self.hop_length = hop_length# + self.spk_embed_dim = spk_embed_dim + self.enc_p = TextEncoder256( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + f0=False, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels + ) + self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) + print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim) + + def remove_weight_norm(self): + self.dec.remove_weight_norm() + self.flow.remove_weight_norm() + self.enc_q.remove_weight_norm() + + def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1] + g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 + m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + z_slice, ids_slice = commons.rand_slice_segments( + z, y_lengths, self.segment_size + ) + o = self.dec(z_slice, g=g) + return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, phone, phone_lengths, sid, rate=None): + g = self.emb_g(sid).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + if rate: + head = int(z_p.shape[2] * rate) + z_p = z_p[:, :, -head:] + x_mask = x_mask[:, :, -head:] + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec(z * x_mask, g=g) + return o, x_mask, (z, z_p, m_p, logs_p) + + +class SynthesizerTrnMs768NSFsid_nono(nn.Module): + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + sr=None, + **kwargs + ): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + # self.hop_length = hop_length# + self.spk_embed_dim = spk_embed_dim + self.enc_p = TextEncoder768( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + f0=False, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels + ) + self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) + print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim) + + def remove_weight_norm(self): + self.dec.remove_weight_norm() + self.flow.remove_weight_norm() + self.enc_q.remove_weight_norm() + + def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1] + #g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 + g = ds.unsqueeze(-1) + #m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) #org + m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths,g=g)#fang add + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + z_slice, ids_slice = commons.rand_slice_segments( + z, y_lengths, self.segment_size + ) + o = self.dec(z_slice, g=g) + return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, phone, phone_lengths, sid, rate=None): + #g = self.emb_g(sid).unsqueeze(-1) + g = sid.unsqueeze(-1).unsqueeze(0) + #m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) + m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths,g=g)#fang add + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + if rate: + head = int(z_p.shape[2] * rate) + z_p = z_p[:, :, -head:] + x_mask = x_mask[:, :, -head:] + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec(z * x_mask, g=g) + return o, x_mask, (z, z_p, m_p, logs_p) + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11, 17] + # periods = [3, 5, 7, 11, 17, 23, 37] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] # + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + # for j in range(len(fmap_r)): + # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class MultiPeriodDiscriminatorV2(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminatorV2, self).__init__() + # periods = [2, 3, 5, 7, 11, 17] + periods = [2, 3, 5, 7, 11, 17, 23, 37] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] # + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + # for j in range(len(fmap_r)): + # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap diff --git a/AIMeiSheng/meisheng_env_preparex.py b/AIMeiSheng/meisheng_env_preparex.py index a7bc0db..f0c9854 100644 --- a/AIMeiSheng/meisheng_env_preparex.py +++ b/AIMeiSheng/meisheng_env_preparex.py @@ -1,37 +1,38 @@ import os from AIMeiSheng.docker_demo.common import (gs_svc_model_path, gs_hubert_model_path, gs_embed_model_path, gs_rmvpe_model_path, download2disk) def meisheng_env_prepare(logging, AIMeiSheng_Path='./'): cos_path = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/dataset/AIMeiSheng/" rmvpe_model_url = cos_path + "rmvpe.pt" if not os.path.exists(gs_rmvpe_model_path): if not download2disk(rmvpe_model_url, gs_rmvpe_model_path): logging.fatal(f"download rmvpe_model err={rmvpe_model_url}") gs_hubert_model_url = cos_path + "hubert_base.pt" if not os.path.exists(gs_hubert_model_path): if not download2disk(gs_hubert_model_url, gs_hubert_model_path): logging.fatal(f"download hubert_model err={gs_hubert_model_url}") - model_svc = "xusong_v2_org_version_alldata_embed1_enzx_diff_fi_e15_s244110.pth" + #model_svc = "xusong_v2_org_version_alldata_embed1_enzx_diff_fi_e15_s244110.pth" + model_svc = "xusong_v2_org_version_alldata_embed1_enzx_diff_ocean_ctl_enc_e22_s363704.pth" base_dir = os.path.dirname(gs_svc_model_path) os.makedirs(base_dir, exist_ok=True) svc_model_url = cos_path + model_svc if not os.path.exists(gs_svc_model_path): if not download2disk(svc_model_url, gs_svc_model_path): logging.fatal(f"download svc_model err={svc_model_url}") model_embed = "model.pt" base_dir = os.path.dirname(gs_embed_model_path) os.makedirs(base_dir, exist_ok=True) embed_model_url = cos_path + model_embed if not os.path.exists(gs_embed_model_path): if not download2disk(embed_model_url, gs_embed_model_path): logging.fatal(f"download embed_model err={embed_model_url}") if __name__ == "__main__": meisheng_env_prepare() diff --git a/AIMeiSheng/meisheng_svc_final.py b/AIMeiSheng/meisheng_svc_final.py index e6e1ec2..9a5d94f 100644 --- a/AIMeiSheng/meisheng_svc_final.py +++ b/AIMeiSheng/meisheng_svc_final.py @@ -1,224 +1,227 @@ import os import sys sys.path.append(os.path.dirname(__file__)) import time import shutil import glob import hashlib import librosa import soundfile import gradio as gr import pandas as pd import numpy as np from AIMeiSheng.RawNet3.infererence_fang_meisheng import get_embed, get_embed_model from myinfer_multi_spk_embed_in_dec_diff_fi_meisheng import svc_main, load_hubert, get_vc, get_rmvpe from gender_classify import load_gender_model from AIMeiSheng.docker_demo.common import gs_svc_model_path, gs_embed_model_path, gs_rmvpe_model_path, gs_err_code_target_silence +from slicex.slice_set_silence import del_noise gs_simple_mixer_path = "/data/gpu_env_common/bin/simple_mixer" ##混音执行文件 tmp_workspace_name = "batch_test_ocean_fi" # 工作空间名 song_folder = "./data_meisheng/" ##song folder gs_work_dir = f"./data_meisheng/{tmp_workspace_name}" # 工作空间路径 pth_model_path = "./weights/xusong_v2_org_version_alldata_embed1_enzx_diff_fi_e15_s244110.pth" ##模型文件 cur_dir = os.path.abspath(os.path.dirname(__file__)) abs_path = os.path.join(cur_dir, song_folder, tmp_workspace_name) + '/' f0_method = None def mix(in_path, acc_path, dst_path): # svc转码到442 svc_442_file = in_path + "_442.wav" st = time.time() cmd = "ffmpeg -i {} -ar 44100 -ac 2 -y {} -loglevel fatal".format(in_path, svc_442_file) os.system(cmd) if not os.path.exists(svc_442_file): return -1 print("transcode,{},sp={}".format(in_path, time.time() - st)) # 混合 st = time.time() cmd = "{} {} {} {} 1".format(gs_simple_mixer_path, svc_442_file, acc_path, dst_path) os.system(cmd) print("mixer,{},sp={}".format(in_path, time.time() - st)) def load_model(): global f0_method embed_model = get_embed_model(gs_embed_model_path) hubert_model = load_hubert() get_vc(gs_svc_model_path) f0_method = get_rmvpe(gs_rmvpe_model_path) print("model preload finish!!!") return embed_model, hubert_model # ,svc_model def meisheng_init(): embed_model, hubert_model = load_model() ##提前加载模型 gender_model = load_gender_model() return embed_model, hubert_model, gender_model def pyin_process_single_rmvpe(input_file): global f0_method if f0_method is None: f0_method = get_rmvpe() rate = 16000 # 44100 # 读取音频文件 y, sr = librosa.load(input_file, sr=rate) len_s = len(y) / sr lim_s = 15 # 10 if (len_s > lim_s): y1 = y[:sr * lim_s] y2 = y[-sr * lim_s:] f0 = f0_method.infer_from_audio(y1, thred=0.03) f0 = f0[f0 < 600] valid_f0 = f0[f0 > 50] mean_pitch1 = np.mean(valid_f0) f0 = f0_method.infer_from_audio(y2, thred=0.03) f0 = f0[f0 < 600] valid_f0 = f0[f0 > 50] mean_pitch2 = np.mean(valid_f0) if abs(mean_pitch1 - mean_pitch2) > 55: mean_pitch_cur = min(mean_pitch1, mean_pitch2) else: mean_pitch_cur = (mean_pitch1 + mean_pitch2) / 2 else: f0 = f0_method.infer_from_audio(y, thred=0.03) f0 = f0[f0 < 600] valid_f0 = f0[f0 > 50] mean_pitch_cur = np.mean(valid_f0) return mean_pitch_cur def meisheng_svc(song_wav, target_wav, svc_out_path, embed_npy, embed_md, hubert_md, paras): ##计算pitch f0up_key = pyin_process_single_rmvpe(target_wav) if f0up_key < 40 or np.isnan(f0up_key):#unvoice return gs_err_code_target_silence ## get embed, 音色 get_embed(target_wav, embed_npy, embed_md) print("svc main start...") svc_main(song_wav, svc_out_path, embed_npy, f0up_key, hubert_md, paras) print("svc main finished!!") + del_noise(song_wav,svc_out_path) + print("del noise in silence") return 0 def process_svc_online(song_wav, target_wav, svc_out_path, embed_md, hubert_md, paras): embed_npy = target_wav[:-4] + '.npy' ##embd npy存储位置 err_code = meisheng_svc(song_wav, target_wav, svc_out_path, embed_npy, embed_md, hubert_md, paras) return err_code def process_svc(song_wav, target_wav, svc_out_path, embed_md, hubert_md, paras): song_wav1, target_wav, svc_out_path = os.path.basename(song_wav), os.path.basename( target_wav), os.path.basename(svc_out_path) # 绝对路径 song_wav, target_wav, svc_out_path = song_wav, abs_path + target_wav, abs_path + svc_out_path embed_npy = target_wav[:-4] + '.npy' ##embd npy存储位置 # similar = meisheng_svc(song_wav,target_wav,svc_out_path,embed_npy,paras) similar = meisheng_svc(song_wav, target_wav, svc_out_path, embed_npy, embed_md, hubert_md, paras) return similar def get_svc(target_yinse_wav, song_name, embed_model, hubert_model, paras): ''' :param target_yinse_wav: 目标音色 :param song_name: 歌曲名字 ;param paras: 其他参数 :return: svc路径名 ''' ##清空工作空间临时路径 if os.path.exists(gs_work_dir): # shutil.rmtree(gs_work_dir) cmd = f"rm -rf {gs_work_dir}/*" os.system(cmd) else: os.makedirs(gs_work_dir) gender = paras['gender'] ##为了确定歌曲 ##目标音色读取 f_dst = os.path.join(gs_work_dir, os.path.basename(target_yinse_wav)) # print("dir :", f_dst,"target_yinse_wav:",target_yinse_wav) # shutil.move(target_yinse_wav, f_dst) ##放在工作目录 shutil.copy(target_yinse_wav, f_dst) target_yinse_wav = f_dst ##歌曲/伴奏 读取(路径需要修改) song_wav = os.path.join("{}{}/{}/vocal321.wav".format(song_folder, gender, song_name)) # 歌曲vocal inf_acc_path = os.path.join("{}{}/{}/acc.wav".format(song_folder, gender, song_name)) # song_wav = './xusong_long.wav' svc_out_path = os.path.join(gs_work_dir, "svc.wav") ###svc结果名字 print("inputMsg:", song_wav, target_yinse_wav, svc_out_path) ## svc process st = time.time() print("start inference...") similar = process_svc(song_wav, target_yinse_wav, svc_out_path, embed_model, hubert_model, paras) print("svc finished!!") print("time cost = {}".format(time.time() - st)) print("out path name {} ".format(svc_out_path)) # ''' ##加混响 print("add reverbration...") svc_out_path_effect = svc_out_path[:-4] + '_effect.wav' cmd = f"/data/gpu_env_common/bin/effect_tool {svc_out_path} {svc_out_path_effect}" print("cmd :", cmd) os.system(cmd) # # 人声伴奏合并 print("add acc...") out_path = svc_out_path_effect[:-4] + '_music.wav' mix(svc_out_path_effect, inf_acc_path, out_path) print("time cost = {}".format(time.time() - st)) print("out path name {} ".format(out_path)) # ''' return svc_out_path def meisheng_func(target_yinse_wav, song_name, paras): ##init embed_model, hubert_model, gender_model = meisheng_init() ###gender predict gender, female_rate, is_pure = gender_model.process(target_yinse_wav) print('=====================') print("gender:{}, female_rate:{},is_pure:{}".format(gender, female_rate, is_pure)) if gender == 0: gender = 'female' elif gender == 1: gender = 'male' elif female_rate > 0.5: gender = 'female' else: gender = 'male' print("modified gender:{} ".format(gender)) print('=====================') ##美声main paras['gender'] = gender ##单位都是ms get_svc(target_yinse_wav, song_name, embed_model, hubert_model, paras) if __name__ == '__main__': # target_yinse_wav = "./raw/meisheng_yinse/female/changying.wav" # 需要完整路径 target_yinse_wav = "./raw/meisheng_yinse/female/target_yinse_cloris.m4a" song_name = "lost_stars" ##歌曲名字 paras = {'gender': None, 'tst': 0, "tnd": None, 'delay': 0, 'song_path': None} # paras = {'gender': 'female', 'tst': 0, "tnd": 30, 'delay': 0} ###片段svc测试 meisheng_func(target_yinse_wav, song_name, paras) diff --git a/AIMeiSheng/myinfer_multi_spk_embed_in_dec_diff_fi_meisheng.py b/AIMeiSheng/myinfer_multi_spk_embed_in_dec_diff_fi_meisheng.py index f1da5a9..4a60e74 100644 --- a/AIMeiSheng/myinfer_multi_spk_embed_in_dec_diff_fi_meisheng.py +++ b/AIMeiSheng/myinfer_multi_spk_embed_in_dec_diff_fi_meisheng.py @@ -1,217 +1,217 @@ import os,sys,pdb,torch now_dir = os.getcwd() sys.path.append(now_dir) import argparse import glob import sys import torch from multiprocessing import cpu_count class Config: def __init__(self,device,is_half): self.device = device self.is_half = is_half self.n_cpu = 0 self.gpu_name = None self.gpu_mem = None self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config() def device_config(self) -> tuple: current_dir = os.path.dirname(os.path.abspath(__file__)) config_path = os.path.join(current_dir, "configs") if torch.cuda.is_available(): i_device = int(self.device.split(":")[-1]) self.gpu_name = torch.cuda.get_device_name(i_device) if ( ("16" in self.gpu_name and "V100" not in self.gpu_name.upper()) or "P40" in self.gpu_name.upper() or "1060" in self.gpu_name or "1070" in self.gpu_name or "1080" in self.gpu_name ): print("16系/10系显卡和P40强制单精度") self.is_half = False for config_file in ["32k.json", "40k.json", "48k.json"]: with open(f"{config_path}/{config_file}", "r") as f: strr = f.read().replace("true", "false") with open(f"{config_path}/{config_file}", "w") as f: f.write(strr) with open(f"{current_dir}/trainset_preprocess_pipeline_print.py", "r") as f: strr = f.read().replace("3.7", "3.0") with open(f"{current_dir}/trainset_preprocess_pipeline_print.py", "w") as f: f.write(strr) else: self.gpu_name = None self.gpu_mem = int( torch.cuda.get_device_properties(i_device).total_memory / 1024 / 1024 / 1024 + 0.4 ) if self.gpu_mem <= 4: with open(f"{current_dir}/trainset_preprocess_pipeline_print.py", "r") as f: strr = f.read().replace("3.7", "3.0") with open(f"{current_dir}/trainset_preprocess_pipeline_print.py", "w") as f: f.write(strr) elif torch.backends.mps.is_available(): print("没有发现支持的N卡, 使用MPS进行推理") self.device = "mps" else: print("没有发现支持的N卡, 使用CPU进行推理") self.device = "cpu" self.is_half = True if self.n_cpu == 0: self.n_cpu = cpu_count() if self.is_half: # 6G显存配置 x_pad = 3 x_query = 10 x_center = 80 #60 x_max = 85#65 else: # 5G显存配置 x_pad = 1 x_query = 6 x_center = 38 x_max = 41 if self.gpu_mem != None and self.gpu_mem <= 4: x_pad = 1 x_query = 5 x_center = 30 x_max = 32 return x_pad, x_query, x_center, x_max index_path="./logs/xusong_v2_org_version_multispk_charlie_puth_embed_in_dec_muloss_show/added_IVF614_Flat_nprobe_1_xusong_v2_org_version_multispk_charlie_puth_embed_in_dec_show_v2.index" # f0method="rmvpe" #harvest or pm index_rate=float("0.0") #index rate device="cuda:0" is_half=True filter_radius=int(3) ##3 resample_sr=int(0) # 0 rms_mix_rate=float(1) # rms混合比例 1,不等于1混合 protect=float(0.33 )## ??? 0.33 fang #print(sys.argv) config=Config(device,is_half) now_dir=os.getcwd() sys.path.append(now_dir) from vc_infer_pipeline_org_embed import VC -from lib.infer_pack.models_embed_in_dec_diff_fi import ( +from lib.infer_pack.models_embed_in_dec_diff_control_enc import ( SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono, SynthesizerTrnMs768NSFsid, SynthesizerTrnMs768NSFsid_nono, ) from lib.audio import load_audio from fairseq import checkpoint_utils from scipy.io import wavfile from AIMeiSheng.docker_demo.common import gs_hubert_model_path # hubert_model=None def load_hubert(): # global hubert_model models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([gs_hubert_model_path],suffix="",) #models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(["checkpoint_best_legacy_500.pt"],suffix="",) hubert_model = models[0] hubert_model = hubert_model.to(device) if(is_half):hubert_model = hubert_model.half() else:hubert_model = hubert_model.float() hubert_model.eval() return hubert_model def vc_single(sid,input_audio,f0_up_key,f0_file,f0_method,file_index,index_rate,hubert_model,paras): global tgt_sr,net_g,vc,version if input_audio is None:return "You need to upload an audio", None f0_up_key = int(f0_up_key) # print("@@xxxf0_up_key:",f0_up_key) audio = load_audio(input_audio,16000) if paras != None: st = int(paras['tst'] * 16000/1000) en = len(audio) if paras['tnd'] != None: en = min(en,int(paras['tnd'] * 16000/1000)) audio = audio[st:en] times = [0, 0, 0] if(hubert_model==None): hubert_model = load_hubert() if_f0 = cpt.get("f0", 1) audio_opt=vc.pipeline_mulprocess(hubert_model,net_g,sid,audio,input_audio,times,f0_up_key,f0_method,file_index,index_rate,if_f0,filter_radius,tgt_sr,resample_sr,rms_mix_rate,version,protect,f0_file=f0_file) #print(times) #print("@@using multi process") return audio_opt def get_vc_core(model_path,is_half): #print("loading pth %s" % model_path) cpt = torch.load(model_path, map_location="cpu") tgt_sr = cpt["config"][-1] cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] if_f0 = cpt.get("f0", 1) version = cpt.get("version", "v1") if version == "v1": if if_f0 == 1: net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half) else: net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) elif version == "v2": if if_f0 == 1: # net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half) else: net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) #print("load model finished") del net_g.enc_q net_g.load_state_dict(cpt["weight"], strict=False) #print("load net_g finished") return tgt_sr,net_g,cpt,version def get_vc1(model_path,is_half): tgt_sr, net_g, cpt, version = get_vc_core(model_path, is_half) net_g.eval().to(device) if (is_half):net_g = net_g.half() else:net_g = net_g.float() vc = VC(tgt_sr, config) n_spk=cpt["config"][-3] return def get_rmvpe(model_path="rmvpe.pt"): from lib.rmvpe import RMVPE global f0_method #print("loading rmvpe model") f0_method = RMVPE(model_path, is_half=True, device='cuda') return f0_method def get_vc(model_path): global n_spk,tgt_sr,net_g,vc,cpt,device,is_half,version tgt_sr, net_g, cpt, version = get_vc_core(model_path, is_half) net_g.eval().to(device) if (is_half):net_g = net_g.half() else:net_g = net_g.float() vc = VC(tgt_sr, config) n_spk=cpt["config"][-3] # return {"visible": True,"maximum": n_spk, "__type__": "update"} # return net_g def svc_main(input_path,opt_path,sid_embed,f0up_key=0,hubert_model=None, paras=None): #print("sid_embed: ",sid_embed) wav_opt = vc_single(sid_embed,input_path,f0up_key,None,f0_method,index_path,index_rate,hubert_model,paras) #print("out_path: ",opt_path) wavfile.write(opt_path, tgt_sr, wav_opt) diff --git a/AIMeiSheng/slicex/slice_set_silence.py b/AIMeiSheng/slicex/slice_set_silence.py new file mode 100644 index 0000000..f1b51b6 --- /dev/null +++ b/AIMeiSheng/slicex/slice_set_silence.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + + +import librosa # Optional. Use any library you like to read audio files. +import soundfile # Optional. Use any library you like to write audio files. +from slicex.slicer_torch import Slicer + + +class silce_silence(): + def __init__(self, sr): + # audio = torch.from_numpy(audio) + self.slicer = Slicer( + sr=sr, + threshold=-40, + min_length=5000, + min_interval=300, + hop_size=10, + max_sil_kept=500 + ) + + def set_silence(self,chunks,sr, target_audio, target_sr): + ''' + :param chunks: slice结果 of song wav + :param sr: song in sr + :param target_audio: svc_out + :param target_sr: svc_out sr + :return: + ''' + # target_audio = np.zeros(int(len(audio)*target_sr/sr),1) + # result = [] + for k, v in chunks.items(): + tag = v["split_time"].split(",") + # if tag[0] != tag[1]: + # result.append((v["slice"], audio[int(tag[0]):int(tag[1])])) + + if( tag[0] != tag[1] and v["slice"] == True):#静音 + st = int(int(tag[0])*target_sr/sr) + en = min(int(int(tag[1])*target_sr/sr), len(target_audio)) + target_audio[st:en] = 0#0.001 * target_audio[st:en] + return target_audio + + def cut(self, audio): + chunks = self.slicer.slice(audio) + chunks = dict(chunks) + return chunks + +def del_noise(wav_in,svc_out): + audio, sr = librosa.load(wav_in, sr=None) # Load an audio file with librosa. + target_audio, target_sr = librosa.load(svc_out, sr=None) # Load an audio file with librosa. + + + slice_sil = silce_silence(sr) + chunks = slice_sil.cut(audio) + target_audio1 = slice_sil.set_silence(chunks, sr, target_audio, target_sr) + soundfile.write(svc_out, target_audio1, target_sr) + return + + + diff --git a/AIMeiSheng/slicex/slicer_torch.py b/AIMeiSheng/slicex/slicer_torch.py new file mode 100644 index 0000000..5b33fcc --- /dev/null +++ b/AIMeiSheng/slicex/slicer_torch.py @@ -0,0 +1,118 @@ +import librosa +import torch +#import torchaudio + + +class Slicer: + def __init__(self, + sr: int, + threshold: float = -40., + min_length: int = 5000, + min_interval: int = 300, + hop_size: int = 20, + max_sil_kept: int = 5000): + if not min_length >= min_interval >= hop_size: + raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size') + if not max_sil_kept >= hop_size: + raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size') + min_interval = sr * min_interval / 1000 + self.threshold = 10 ** (threshold / 20.) + self.hop_size = round(sr * hop_size / 1000) + self.win_size = min(round(min_interval), 4 * self.hop_size) + self.min_length = round(sr * min_length / 1000 / self.hop_size) + self.min_interval = round(min_interval / self.hop_size) + self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) + + def _apply_slice(self, waveform, begin, end): + if len(waveform.shape) > 1: + return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)] + else: + return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)] + + # @timeit + def slice(self, waveform): + if len(waveform.shape) > 1: + samples = librosa.to_mono(waveform) + else: + samples = waveform + if samples.shape[0] <= self.min_length: + return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} + rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) + sil_tags = [] + silence_start = None + clip_start = 0 + for i, rms in enumerate(rms_list): + # Keep looping while frame is silent. + if rms < self.threshold: + # Record start of silent frames. + if silence_start is None: + silence_start = i + continue + # Keep looping while frame is not silent and silence start has not been recorded. + if silence_start is None: + continue + # Clear recorded silence start if interval is not enough or clip is too short + is_leading_silence = silence_start == 0 and i > self.max_sil_kept + need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length + if not is_leading_silence and not need_slice_middle: + silence_start = None + continue + # Need slicing. Record the range of silent frames to be removed. + if i - silence_start <= self.max_sil_kept: + pos = rms_list[silence_start: i + 1].argmin() + silence_start + if silence_start == 0: + sil_tags.append((0, pos)) + else: + sil_tags.append((pos, pos)) + clip_start = pos + elif i - silence_start <= self.max_sil_kept * 2: + pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin() + pos += i - self.max_sil_kept + pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start + pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept + if silence_start == 0: + sil_tags.append((0, pos_r)) + clip_start = pos_r + else: + sil_tags.append((min(pos_l, pos), max(pos_r, pos))) + clip_start = max(pos_r, pos) + else: + pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start + pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept + if silence_start == 0: + sil_tags.append((0, pos_r)) + else: + sil_tags.append((pos_l, pos_r)) + clip_start = pos_r + silence_start = None + # Deal with trailing silence. + total_frames = rms_list.shape[0] + if silence_start is not None and total_frames - silence_start >= self.min_interval: + silence_end = min(total_frames, silence_start + self.max_sil_kept) + pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start + sil_tags.append((pos, total_frames + 1)) + # Apply and return slices. + if len(sil_tags) == 0: + return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} + else: + chunks = [] + # 第一段静音并非从头开始,补上有声片段 + if sil_tags[0][0]: + chunks.append( + {"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"}) + for i in range(0, len(sil_tags)): + # 标识有声片段(跳过第一段) + if i: + chunks.append({"slice": False, + "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"}) + # 标识所有静音片段 + chunks.append({"slice": True, + "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"}) + # 最后一段静音并非结尾,补上结尾片段 + if sil_tags[-1][1] * self.hop_size < len(waveform): + chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"}) + chunk_dict = {} + for i in range(len(chunks)): + chunk_dict[str(i)] = chunks[i] + return chunk_dict +