Page Menu
Home
Phabricator
Search
Configure Global Search
Log In
Files
F4880311
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
91 KB
Subscribers
None
View Options
diff --git a/AutoCoverTool/online/common.py b/AutoCoverTool/online/common.py
index 982a1f0..87fe3cf 100644
--- a/AutoCoverTool/online/common.py
+++ b/AutoCoverTool/online/common.py
@@ -1,163 +1,235 @@
# -*-encoding=utf8-*-
+import os
import time
import pymysql
banned_user_map = {
"host": "sg-songbook00.db.starmaker.co",
"user": "worker",
"passwd": "gRYppQtdTpP3nFzH",
"db": "starmaker"
}
gs_songbook_test_banned_user_map = {
"host": "sg-test-server-goapi-1",
"user": "root",
"passwd": "solo2018",
"db": "av_db"
}
banned_user_map_v1 = {
"host": "sg-starmaker-device-r2.db.starmaker.co",
"user": "worker",
"passwd": "gRYppQtdTpP3nFzH",
"db": "mis"
}
banned_user_map_v2 = {
"host": "sg-sm-img-r1.starmaker.co",
"user": "worker",
"passwd": "gRYppQtdTpP3nFzH",
"db": "sm"
}
# Lookup table for querying the sharded databases: shard DB name -> host, plus the shared connection settings.
shard_map = {
"shard_sm_12": "sg-shard02-r2.db.starmaker.co",
"shard_sm_13": "sg-shard02-r2.db.starmaker.co",
"shard_sm_14": "sg-shard02-r2.db.starmaker.co",
"shard_sm_15": "sg-shard02-r2.db.starmaker.co",
"shard_sm_30": "sg-shard02-r2.db.starmaker.co",
"shard_sm_31": "sg-shard02-r2.db.starmaker.co",
"shard_sm_20": "sg-shard02-r2.db.starmaker.co",
"shard_sm_21": "sg-shard02-r2.db.starmaker.co",
"shard_sm_22": "sg-shard03-r2.db.starmaker.co",
"shard_sm_23": "sg-shard03-r2.db.starmaker.co",
"shard_sm_24": "sg-shard03-r2.db.starmaker.co",
"shard_sm_25": "sg-shard03-r2.db.starmaker.co",
"shard_sm_26": "sg-shard03-r2.db.starmaker.co",
"shard_sm_27": "sg-shard03-r2.db.starmaker.co",
"shard_sm_28": "sg-shard03-r2.db.starmaker.co",
"shard_sm_29": "sg-shard03-r2.db.starmaker.co",
"shard_sm_0": "sg-shard00-r2.db.starmaker.co",
"shard_sm_1": "sg-shard00-r2.db.starmaker.co",
"shard_sm_2": "sg-shard00-r2.db.starmaker.co",
"shard_sm_3": "sg-shard00-r2.db.starmaker.co",
"shard_sm_4": "sg-shard00-r2.db.starmaker.co",
"shard_sm_5": "sg-shard00-r2.db.starmaker.co",
"shard_sm_16": "sg-shard00-r2.db.starmaker.co",
"shard_sm_17": "sg-shard00-r2.db.starmaker.co",
"shard_sm_6": "sg-shard01-r2.db.starmaker.co",
"shard_sm_7": "sg-shard01-r2.db.starmaker.co",
"shard_sm_8": "sg-shard01-r2.db.starmaker.co",
"shard_sm_9": "sg-shard01-r2.db.starmaker.co",
"shard_sm_10": "sg-shard01-r2.db.starmaker.co",
"shard_sm_11": "sg-shard01-r2.db.starmaker.co",
"shard_sm_18": "sg-shard01-r2.db.starmaker.co",
"shard_sm_19": "sg-shard01-r2.db.starmaker.co",
"shard_sm_32": "sg-shard04-r2.db.starmaker.co",
"shard_sm_33": "sg-shard04-r2.db.starmaker.co",
"shard_sm_34": "sg-shard04-r2.db.starmaker.co",
"shard_sm_35": "sg-shard04-r2.db.starmaker.co",
"shard_sm_36": "sg-shard04-r2.db.starmaker.co",
"shard_sm_37": "sg-shard04-r2.db.starmaker.co",
"shard_sm_38": "sg-shard04-r2.db.starmaker.co",
"shard_sm_39": "sg-shard04-r2.db.starmaker.co",
- "shard_sm_40": "sg-shard05-r2.db.starmaker.co",
+ # "shard_sm_40": "sg-shard05-r2.db.starmaker.co",
"shard_sm_41": "sg-shard05-r2.db.starmaker.co",
"shard_sm_42": "sg-shard05-r2.db.starmaker.co",
"shard_sm_43": "sg-shard05-r2.db.starmaker.co",
"shard_sm_44": "sg-shard05-r2.db.starmaker.co",
"shard_sm_45": "sg-shard05-r2.db.starmaker.co",
"shard_sm_46": "sg-shard05-r2.db.starmaker.co",
"shard_sm_47": "sg-shard05-r2.db.starmaker.co",
"shard_sm_48": "sg-shard05-r2.db.starmaker.co",
"shard_sm_49": "sg-shard05-r2.db.starmaker.co",
"shard_sm_50": "sg-shard05-r2.db.starmaker.co",
"name": "shard_sm_{}",
"port": 3306,
"user": "readonly",
"passwd": "JKw6woZgRXsveegL"
}
def connect_db(host="research-db-r1.starmaker.co", port=3306, user="root", passwd="Qrdl1130", db=""):
    """Open and return a pymysql connection.

    :param host: MySQL host name.
    :param port: MySQL port.
    :param user: login user.
    :param passwd: login password.
    :param db: default database/schema name.
    :return: the object returned by ``pymysql.connect``; the caller must close it.
    """
    # The password is masked on purpose: this line goes to stdout and would
    # otherwise leak credentials into whatever collects the job output.
    print("connect mysql host={} port={} user={} passwd={} db={}".format(host, port, user, "***", db))
    return pymysql.connect(host=host, port=port, user=user, passwd=passwd, db=db)
def get_data_by_mysql(sql, ban=banned_user_map):
    """Run a query and return all rows.

    :param sql: SQL text to execute.
    :param ban: connection settings dict with "host", "user", "passwd" and "db" keys.
    :return: whatever ``cursor.fetchall()`` returns for the query.
    """
    db = connect_db(host=ban["host"], passwd=ban["passwd"], user=ban["user"],
                    db=ban["db"])
    try:
        db_cursor = db.cursor()
        try:
            # Only the first 100 characters of long statements are echoed.
            if len(sql) < 100:
                print("execute = {}".format(sql))
            else:
                print("execute = {}...".format(sql[:100]))
            db_cursor.execute(sql)
            res = db_cursor.fetchall()
        finally:
            db_cursor.close()
    finally:
        # Close the connection even when execute()/fetchall() raises.
        db.close()
    print("res size={}".format(len(res)))
    return res
def get_shard_db(user_id):
    """Return the shard index encoded in the bits above bit 47 of ``user_id``.

    :param user_id: integer id, or a string holding an integer or a float
        representation of one (e.g. ``"123.0"``).
    :return: ``user_id >> 48`` as an int.
    """
    try:
        # Exact integer path: going through float() first rounds ids above
        # 2**53 and can push an id across a shard boundary.
        numeric_id = int(user_id)
    except ValueError:
        # Strings such as "123.0" or "1e15" are not accepted by int().
        numeric_id = int(float(user_id))
    return numeric_id >> 48
def get_shard_data_by_sql(sql, user_id):
    """Run a query on the shard database that ``user_id`` maps to.

    :param sql: SQL text to execute.
    :param user_id: id used to pick the shard (see ``get_shard_db``).
    :return: whatever ``cursor.fetchall()`` returns for the query.
    :raises KeyError: if the computed shard name has no entry in ``shard_map``.
    """
    shard_id = get_shard_db(user_id)
    db_name = shard_map["name"].format(shard_id)
    host = shard_map[db_name]
    db = connect_db(host=host, passwd=shard_map["passwd"], user=shard_map["user"], db=db_name)
    try:
        db_cursor = db.cursor()
        try:
            # Only the first 100 characters of long statements are echoed.
            if len(sql) < 100:
                print("execute = {}".format(sql))
            else:
                print("execute = {}...".format(sql[:100]))
            db_cursor.execute(sql)
            res = db_cursor.fetchall()
        finally:
            db_cursor.close()
    finally:
        # Close the connection even when the query fails.
        db.close()
    print("res size={}".format(len(res)))
    return res
def get_all_shared_data_by_sql(sql):
    """Run the same query on every shard listed in ``shard_map`` and merge the rows.

    :param sql: SQL text to execute on each shard.
    :return: list with the rows of all shards, in ascending shard-index order.
    """
    # Derive the shard indexes from the configured keys instead of a fixed
    # range, so that every configured shard (including the highest one) is
    # visited and removed/commented-out shards are skipped automatically.
    prefix = shard_map["name"].format("")
    shard_ids = sorted(
        int(key[len(prefix):])
        for key in shard_map
        if key.startswith(prefix) and key[len(prefix):].isdigit()
    )

    res = []
    for shard_id in shard_ids:
        db_name = shard_map["name"].format(shard_id)
        host = shard_map[db_name]
        db = connect_db(host=host, passwd=shard_map["passwd"], user=shard_map["user"], db=db_name)
        try:
            db_cursor = db.cursor()
            try:
                # Only the first 100 characters of long statements are echoed.
                if len(sql) < 100:
                    print("execute = {}".format(sql))
                else:
                    print("execute = {}...".format(sql[:100]))
                db_cursor.execute(sql)
                cur_data = db_cursor.fetchall()
            finally:
                db_cursor.close()
        finally:
            # Close the connection even when the query fails on this shard.
            db.close()
        print("res size={}".format(len(cur_data)))
        res.extend(cur_data)
    return res
+
+
def read_file(in_file):
    """Return the lines of a text file, each keeping its trailing newline.

    :param in_file: path of the file to read.
    :return: list of lines as produced by ``readlines()``.
    """
    with open(in_file, "r") as handle:
        return handle.readlines()
def write2file(file_path, data):
    """Write every item of ``data`` to ``file_path``, one per line.

    :param file_path: destination path; the file is overwritten.
    :param data: iterable of strings; a newline is appended to each item.
    """
    with open(file_path, "w") as sink:
        for entry in data:
            sink.write(entry + "\n")
def update_db(sql, ban=banned_user_map):
    """Execute a write statement and commit it.

    :param sql: SQL text to execute.
    :param ban: connection settings dict with "host", "user", "passwd" and "db" keys.
    """
    db = connect_db(host=ban["host"], passwd=ban["passwd"], user=ban["user"],
                    db=ban["db"])
    try:
        db_cursor = db.cursor()
        try:
            # Only the first 100 characters of long statements are echoed.
            if len(sql) < 100:
                print("execute = {}".format(sql))
            else:
                print("execute = {}...".format(sql[:100]))
            db_cursor.execute(sql)
            db.commit()
        finally:
            db_cursor.close()
    finally:
        # Close the connection even when execute()/commit() raises.
        db.close()
+
+
def exec_cmd(cmd):
    """Run a shell command via ``os.system``.

    :param cmd: command line to execute.
    :return: True when ``os.system`` returns 0, otherwise False.
    """
    return os.system(cmd) == 0
+
+
def exec_cmd_and_result(cmd):
    """Run a shell command via ``os.popen`` and return what it wrote to stdout.

    :param cmd: command line to execute.
    :return: the captured standard output as a string.
    """
    with os.popen(cmd) as pipe:
        return pipe.read()
+
+
def check_file(key, region, bucket_name, cos_cmd, conf_path):
    """Check that an object exists in the bucket and is not empty.

    Runs ``<cos_cmd> ... info <key>`` and reads the value printed on the
    ``Content-Length`` line.

    :param key: object key inside the bucket.
    :param region: bucket region.
    :param bucket_name: bucket name.
    :param cos_cmd: path of the coscmd executable.
    :param conf_path: path of the coscmd configuration file.
    :return: True when a positive Content-Length was reported, otherwise False.
    """
    cmd = "{} -c {} -r {} -b {} info {}".format(cos_cmd, conf_path, region, bucket_name, key) \
          + "| grep Content-Length |awk '{print $2}'"
    res_str = exec_cmd_and_result(cmd)
    try:
        size = float(res_str)
    except ValueError:
        # Nothing (or something non-numeric) was printed, e.g. the object
        # does not exist: report "not there" instead of crashing.
        return False
    return size > 0
+
+
def upload_file2cos(key, file_path, region, bucket_name, cos_cmd, conf_path):
    """Upload a local file to COS and verify that it arrived.

    :param key: destination object key inside the bucket.
    :param file_path: local file to upload.
    :param region: bucket region.
    :param bucket_name: bucket name.
    :param cos_cmd: path of the coscmd executable.
    :param conf_path: path of the coscmd configuration file.
    :return: True when the upload command succeeded and the uploaded object
        reports a positive Content-Length, otherwise False.
    """
    cmd = "{} -c {} -r {} -b {} upload {} {}".format(cos_cmd, conf_path, region, bucket_name, file_path, key)
    print(cmd)
    if not exec_cmd(cmd):
        return False

    # Verify the upload by asking for the object's size.
    cmd = "{} -c {} -r {} -b {} info {}".format(cos_cmd, conf_path, region, bucket_name, key) \
          + "| grep Content-Length |awk '{print $2}'"
    res_str = exec_cmd_and_result(cmd)
    try:
        size = float(res_str)
    except ValueError:
        # No (or non-numeric) Content-Length line: treat as a failed upload
        # instead of raising.
        return False
    return size > 0
diff --git a/AutoCoverTool/online/train_user_get_data.py b/AutoCoverTool/online/train_user_get_data.py
new file mode 100644
index 0000000..77278d2
--- /dev/null
+++ b/AutoCoverTool/online/train_user_get_data.py
@@ -0,0 +1,18 @@
+from online.train_users_model import GenerateUserData
+
+"""
+成对的数据,其中本代码用于从服务器获取数据,train_users_model用于在北京训练模型
+操作分为三部分:
+1. python3 train_user_get_data.py | 生成数据
+2. python3 train_users_model.py
+3. python3 train_users_download_model.py | 下载数据
+"""
+
if __name__ == '__main__':
    # Each entry is [user_id, COS key of the model, gender value to store].
    pending_models = [
        ["1741807722", "av_res/so_vits_models/3.0_v1/1741807722_1.pth", 2],
    ]

    generator = GenerateUserData()
    # generator.get_user_ids_from_db()
    generator.download_model_and_update_gender(pending_models)
diff --git a/AutoCoverTool/online/train_users_model.py b/AutoCoverTool/online/train_users_model.py
new file mode 100644
index 0000000..be90fd2
--- /dev/null
+++ b/AutoCoverTool/online/train_users_model.py
@@ -0,0 +1,315 @@
+"""
+训练人声模型
+输入: 训练使用的音频(干声文件)
+输出: 训练出的模型(训练1000轮次)、测试音频的生成结果
+"""
+import os
+import time
+import json
+import glob
+import shutil
+import GPUtil
+from script.train_user_by_one_media import SoVitsSVCOnlineTrain
+from ref.online.voice_class_online import VoiceClass
+from online.beanstalk_helper import BeanstalkHelper
+from online.common import update_db, get_data_by_mysql, upload_file2cos, exec_cmd, get_all_shared_data_by_sql
+
+# cos资源目录: av-audit-sync-bj-1256122840
+gs_cos_dir = "av_res/so_vits_models/train_users/{user_id}" # 内部结构: record_id.mp4[干声],test_svc_file.m4a
+gs_record_path = os.path.join(gs_cos_dir, "{record_id}.mp4")
+gs_svc_path = os.path.join(gs_cos_dir, "example.m4a")
+gs_model_path = os.path.join("av_res/so_vits_models/3.0_v1/", "{user_id}_{gender}.pth")
+
+gs_bean_config = {"addr": "sg-test-common-box-1:11300", "consumer": "auto_user_svc_trainer_v1"}
+# gs_coscmd = "/data/gpu_env_common/env/anaconda3/envs/auto_song_cover/bin/coscmd"
+
+
gs_coscmd = "/bin/coscmd"
gs_config_path = "/home/normal/.cos.conf"
# Guard against hosts where no GPU is reported: indexing an empty list here
# would make the whole module (and everything importing it) fail at import.
_gs_gpus = GPUtil.getGPUs()
if _gs_gpus and _gs_gpus[0].name == "Tesla T4":
    gs_coscmd = "/data/gpu_env_common/env/anaconda3/envs/auto_song_cover_t4/bin/coscmd"
    gs_config_path = "/data/gpu_env_common/bin/cos.conf"
+
+gs_model_dir = "/data/gpu_env_common/res/av_svc/models"
+gs_cache_dir = "/tmp/train_users_model_cache_dir"
+gs_download_cache_dir = "/tmp/train_users_model_download_cache_dir"
+gs_region = "ap-beijing"
+gs_bucket_name = "av-audit-sync-bj-1256122840"
+
+gs_tum_err_success = 0
+gs_tum_err_already_processed = 1
+gs_tum_err_download_origin = 2
+gs_tum_err_train_and_inf = 3
+gs_tum_err_test_media_transcode = 4
+gs_tum_err_gender_err = 5
+
+
def get_d(audio_path):
    """Return the duration of a media file in seconds, as reported by ffprobe.

    :param audio_path: path of the media file.
    :return: ``format.duration`` from ffprobe's JSON output as a float, or 0
        when ffprobe cannot be run, prints no usable JSON, or reports no
        duration.
    """
    import subprocess

    # The JSON has to be read from ffprobe's stdout; an exit-status helper
    # cannot provide it. An argument list also avoids shell quoting problems.
    cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
           "-show_format", "-show_streams", str(audio_path)]
    try:
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        data = json.loads(proc.stdout.decode("utf-8", "replace") or "{}")
    except (OSError, ValueError):
        # OSError: ffprobe is not installed; ValueError: output is not JSON.
        return 0

    fmt = data.get("format") if isinstance(data, dict) else None
    if isinstance(fmt, dict) and "duration" in fmt:
        try:
            return float(fmt["duration"])
        except (TypeError, ValueError):
            return 0
    return 0
+
+
def download_url(url, dst_path):
    """Download ``url`` to ``dst_path`` with wget.

    :param url: source URL; when it is the empty string nothing is downloaded
        and the result only says whether ``dst_path`` already exists.
    :param dst_path: local destination path.
    :return: True when a non-empty file is available at ``dst_path``.
    """
    if url == "":
        return os.path.exists(dst_path)

    import subprocess

    try:
        # Argument list instead of a shell string: URLs routinely contain
        # characters such as '&' or '?' that a shell would interpret.
        ret = subprocess.call(["wget", str(url), "-O", dst_path])
    except OSError:
        # wget is not installed / not executable.
        ret = -1

    if ret == 0 and os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
        return True

    # "wget -O" creates the target file before transferring anything, so a
    # failed download leaves an empty/partial file behind. Remove it so that
    # an existence check cannot mistake it for a successful download.
    if os.path.exists(dst_path):
        os.remove(dst_path)
    return False
+
+
def update_gender(user_id, gender):
    """Set the gender of a user's model row, but only if it is currently 3 (unknown).

    :param user_id: user whose row in ``av_db.av_svc_model`` is updated.
    :param gender: new gender value to store.
    :return: None
    """
    select_sql = "select * from av_db.av_svc_model where user_id=\"{}\" and gender=3".format(user_id)
    rows = get_data_by_mysql(select_sql)
    # Only act when exactly one matching row with gender=3 exists.
    if len(rows) != 1:
        return
    update_sql = "update av_db.av_svc_model set gender={} where user_id=\"{}\"".format(gender, user_id)
    update_db(update_sql)
+
+
class StatHelper:
    """
    Helper that reads and updates the processing state of one user's model.
    """

    def __init__(self, user_id, record_id):
        self.record_id = record_id
        self.user_id = user_id

    def is_processed(self):
        """
        Tell whether the database has a row for this user whose gender is
        not one of -1, 1, 2, 3.

        Original note (translated): a row that exists and is not -1/1/2/3
        means it can be processed; after processing it is set to 3, and once
        a person has listened to the result it is set to 1 or 2.

        NOTE(review): update_model_url() inserts rows with gender=3, which
        this query excludes, so a freshly processed user is not reported here
        -- confirm the intended condition.
        :return: True when at least one such row exists.
        """
        query = "select * from av_db.av_svc_model where user_id = \"{}\" and gender not in (-1, 1, 2, 3)".format(
            self.user_id)
        rows = get_data_by_mysql(query)
        return len(rows) > 0

    def update_model_url(self, local_model_path, gender, test_svc_m4a):
        """
        Upload the model and the test audio to COS, then record the model in
        the database.

        :param local_model_path: local path of the trained model file.
        :param gender: gender value used to build the model's COS key.
        :param test_svc_m4a: local path of the generated test audio.
        :return: None
        """
        if self.is_processed():
            return

        # Upload to COS first; the database is only touched when both
        # uploads succeeded.
        model_key = gs_model_path.format(user_id=self.user_id, gender=gender)
        audio_key = gs_svc_path.format(user_id=self.user_id)

        if not upload_file2cos(model_key, local_model_path, gs_region, gs_bucket_name, gs_coscmd,
                               gs_config_path):
            return
        if not upload_file2cos(audio_key, test_svc_m4a, gs_region, gs_bucket_name, gs_coscmd,
                               gs_config_path):
            return

        insert_sql = "insert into av_db.av_svc_model (model_version, user_id, model_url, gender) values (\"sovits3.0_v1\", \"{}\", \"{}\", 3)".format(
            self.user_id, model_key)
        update_db(insert_sql)
+
+
class TrainUserModelOnline:
    """Queue consumer that trains one voice model per job and publishes it."""

    def __init__(self):
        g_path = os.path.join(gs_model_dir, 'sunyanzi_base_2000.pth')
        d_path = os.path.join(gs_model_dir, 'sunyanzi_base_d_2000.pth')
        # Source audio converted with the freshly trained model to produce the test sample.
        self.test_svc_src_file = os.path.join(gs_model_dir, "../syz_test.wav")
        self.ssot_inst = SoVitsSVCOnlineTrain(g_path, d_path)
        # Per-job working directory; set by process() before each job.
        self.cache_dir = None

    def download_file(self, user_id, record_id):
        """
        Download the user's recording from COS into the job's cache directory.

        :param user_id: owner of the recording.
        :param record_id: recording id.
        :return: True when the local file exists after the download command.
        """
        filepath = gs_record_path.format(user_id=user_id, record_id=record_id)
        local_path = os.path.join(self.cache_dir, "{}.mp4".format(record_id))
        cmd = "{} -c {} -r {} -b {} download {} {}".format(gs_coscmd, gs_config_path, gs_region, gs_bucket_name,
                                                           filepath, local_path)
        exec_cmd(cmd)
        return os.path.exists(local_path)

    def process_one(self, user_id, record_id, gender):
        """
        Handle one job:
        1. check whether it was already processed
        2. download the recording
        3. train the model and generate the test audio
        4. upload the results and update the database

        :param user_id: owner of the recording.
        :param record_id: recording id.
        :param gender: gender value used in the model file name.
        :return: one of the ``gs_tum_err_*`` codes.
        """
        sh_inst = StatHelper(user_id, record_id)
        # Skip work that was already done.
        if sh_inst.is_processed():
            print("train_user_model,user_id={},record_id={}, already_processed".format(user_id, record_id))
            return gs_tum_err_already_processed

        # Download the dry-vocal recording.
        if not self.download_file(user_id, record_id):
            print("train_user_model,user_id={},record_id={}, download_file err".format(user_id, record_id))
            return gs_tum_err_download_origin

        origin_path = os.path.join(self.cache_dir, "{}.mp4".format(record_id))

        # Train and generate the test audio.
        dst_path = os.path.join(self.cache_dir, "example.wav")
        dst_m4a_path = os.path.join(self.cache_dir, "example.m4a")
        dst_model_path = os.path.join(self.cache_dir, "{}_{}.pth".format(user_id, gender))
        ret = self.ssot_inst.process_train_and_infer(origin_path, self.test_svc_src_file, dst_path, dst_model_path,
                                                     {"max_step": 1000})
        if ret != 0:
            print("train_user_model,user_id={},record_id={}, process_train_and_infer err={}".format(user_id, record_id,
                                                                                                    ret))
            return gs_tum_err_train_and_inf

        # Transcode the generated audio.
        cmd = "ffmpeg -i {} -ar 32000 -ac 1 {}".format(dst_path, dst_m4a_path)
        print(cmd)
        os.system(cmd)
        if not os.path.exists(dst_m4a_path):
            print("train_user_model,user_id={},record_id={}, transcode".format(user_id, record_id))
            return gs_tum_err_test_media_transcode
        # NOTE(review): update_model_url() returns nothing, so a failed upload
        # is still reported as success here -- confirm whether that is intended.
        sh_inst.update_model_url(dst_model_path, gender, dst_m4a_path)
        return gs_tum_err_success

    def process(self):
        """Consume jobs from the beanstalk tube forever, one at a time."""
        bean_helper = BeanstalkHelper(gs_bean_config)
        bean = bean_helper.get_beanstalkd()
        bean.watch(gs_bean_config["consumer"])
        if not os.path.exists(gs_cache_dir):
            os.makedirs(gs_cache_dir)
        while True:
            payload = bean.reserve(5)
            print(payload)
            if not payload:
                print("bean sleep...")
                continue

            # Parse inside a guard: a malformed message (bad JSON or a missing
            # key) must not kill the worker loop. It is dropped like any other
            # finished job so that it is not delivered again.
            try:
                in_data = json.loads(payload.body)
                user_id = in_data["user_id"]  # COS layout is user_id/record_id.mp4
                record_id = in_data["record_id"]
                gender = in_data["gender"]
            except (ValueError, KeyError, TypeError) as ex:
                print("train_user_model,bad payload={}, ex={}".format(payload.body, ex))
                payload.delete()
                continue

            # Start every job from an empty working directory.
            self.cache_dir = os.path.join(gs_cache_dir, "{}".format(user_id))
            if os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)
            os.makedirs(self.cache_dir)
            ret = "exp"
            try:
                ret = self.process_one(user_id, record_id, gender)
            except Exception as ex:
                print("train_user_model,user_id={},record_id={}, ex={}".format(user_id, record_id, ex))
            print("train_user_model,user_id={},record_id={}, ret={}".format(user_id, record_id, ret))
            payload.delete()
            if os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)
+
+
class GenerateUserData:
    """
    Select suitable recordings from the database.
    Requirements for a recording:
    1. it is a dry vocal
    2. it is classified as pure voice
    3. it gets uploaded to the Beijing COS bucket
    """

    def __init__(self):
        music_voice_pure_model = os.path.join(gs_model_dir, "voice_005_rec_v5.pth")
        music_voice_no_pure_model = os.path.join(gs_model_dir, "voice_10_v5.pth")
        gender_pure_model = os.path.join(gs_model_dir, "gender_8k_ratev5_v6_adam.pth")
        gender_no_pure_model = os.path.join(gs_model_dir, "gender_8k_v6_adam.pth")
        self.voice_class_inst = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model,
                                           gender_no_pure_model)

    def process_one(self, origin_path):
        """
        Classify one recording.

        :param origin_path: local path of the recording.
        :return: the external gender code (1 male, 2 female, 3 unknown), or
            ``gs_tum_err_gender_err`` when the recording is not pure voice or
            classification failed.
        """
        # Classifier codes are 0 female, 1 male, 2 unknown; the external
        # convention is 1 male, 2 female, 3 unknown.
        gender_map = [2, 1, 3]
        gender, rate, is_pure = self.voice_class_inst.process(origin_path)
        if not is_pure or rate == -1 or gender not in (0, 1, 2):
            print("train_user_model,gender check, res={}".format([gender, rate, is_pure]))
            return gs_tum_err_gender_err
        return gender_map[gender]

    def get_user_ids_from_db(self):
        """
        Query recent recordings from all shards, keep the ones that classify
        as male/female pure voice, upload them to COS and enqueue one training
        job per user.

        :return: None
        """
        bean_helper = BeanstalkHelper(gs_bean_config)
        cur_time = time.time() - 3600 * 4
        sql = """
        select id, user_id,recording_url
        from recording
        where created_on > {cur_time}
        and grade in ('A++', 'A+')
        and is_public = 1 and is_deleted = 0 and media_type in (1, 2, 3, 4, 9, 10)
        """.format(cur_time=cur_time)
        data = get_all_shared_data_by_sql(sql)
        if not os.path.exists(gs_download_cache_dir):
            os.makedirs(gs_download_cache_dir)

        user_in_bean_cnt = 0
        user_cnt = 0

        # Set instead of list: membership is tested once per row.
        user_ids_in_bean = set()
        st = time.time()
        for record_id, user_id, record_url in data:
            # One job per user is enough.
            if user_id in user_ids_in_bean:
                continue

            origin_path = os.path.join(gs_download_cache_dir, "{}_{}.mp4".format(user_id, record_id))
            record_url = str(record_url).replace("master.mp4", "origin_master.mp4")
            if download_url(record_url, origin_path):
                gender = self.process_one(origin_path)
                # process_one() returns external codes; only male (1) and
                # female (2) are usable. Unknown (3) and the error code are
                # rejected.
                if gender in (1, 2):
                    # Upload the recording to COS.
                    key = gs_record_path.format(user_id=user_id, record_id=record_id)
                    if upload_file2cos(key, origin_path, gs_region, gs_bucket_name, gs_coscmd, gs_config_path):
                        message = json.dumps(
                            {"user_id": str(user_id), "record_id": str(record_id), "gender": int(gender)})
                        bean_helper.put_payload_to_beanstalk(gs_bean_config["consumer"], message, ttr=2 * 86400)
                        user_ids_in_bean.add(user_id)
                        print("put msg to db={}".format(message))
                        user_in_bean_cnt += 1
                        continue
            if user_cnt % 1000 == 0:
                print("{}/{}/{} sp={}".format(user_in_bean_cnt, user_cnt, len(data), time.time() - st))

            user_cnt += 1
        print("finish {}/{}/{} sp={}".format(user_in_bean_cnt, user_cnt, len(data), time.time() - st))

    def download_model_and_update_gender(self, arr):
        """
        Download each model from COS and then store its gender in the database.

        :param arr: iterable of ``[user_id, model_url, gender]`` entries.
        :return: None
        """
        st = time.time()
        for idx, msg in enumerate(arr):
            user_id, model_url, gender = msg
            model_name = str(model_url).split("/")[-1]
            local_path = os.path.join("/data/prod/so_vits_models/3.0/unknown/{}".format(model_name))
            cmd = "{} -c {} -r {} -b {} download {} -f {}".format(gs_coscmd, gs_config_path, gs_region, gs_bucket_name,
                                                                  model_url, local_path)
            print(cmd)
            downloaded = exec_cmd(cmd)
            if idx % 10 == 0:
                print("{}/{} sp={}".format(idx, len(arr), time.time() - st))

            if not downloaded:
                # Do not record a gender for a model that is not available
                # locally; report it and move on to the next entry.
                print("download model failed, skip update_gender, user_id={}, model_url={}".format(user_id, model_url))
                continue

            # Update the database.
            update_gender(user_id, gender)
        print("finish {} sp={}".format(len(arr), time.time() - st))
+
+
def put_data2bean(user_id="3096224751941488", record_id="3096224802670844", gender=3):
    """Enqueue a single training job by hand.

    The consumer reads "user_id", "record_id" and "gender" from every message,
    so all three keys are sent.

    :param user_id: user to train a model for.
    :param record_id: recording to train from.
    :param gender: gender code for the job (1 male, 2 female, 3 unknown).
    """
    bean_helper = BeanstalkHelper(gs_bean_config)
    message = json.dumps({"user_id": str(user_id), "record_id": str(record_id), "gender": int(gender)})
    bean_helper.put_payload_to_beanstalk(gs_bean_config["consumer"], message, ttr=2 * 86400)
    print("input_finish....")
+
+
if __name__ == '__main__':
    # Run the training consumer; the data-generation step is kept below for
    # manual use.
    TrainUserModelOnline().process()
    # gud_inst = GenerateUserData()
    # gud_inst.get_user_ids_from_db()
diff --git a/AutoCoverTool/ref/online/voice_class_online.py b/AutoCoverTool/ref/online/voice_class_online.py
index eab43a7..c101a20 100644
--- a/AutoCoverTool/ref/online/voice_class_online.py
+++ b/AutoCoverTool/ref/online/voice_class_online.py
@@ -1,420 +1,421 @@
"""
男女声分类在线工具
1 转码为16bit单声道
2 均衡化
3 模型分类
"""
import os
import sys
import librosa
import shutil
import logging
import time
import torch.nn.functional as F
import numpy as np
from model import *
# from common import bind_kernel
logging.basicConfig(level=logging.INFO)
os.environ["LRU_CACHE_CAPACITY"] = "1"
# torch.set_num_threads(1)
# bind_kernel(1)
"""
临时用一下,全局使用的变量
"""
transcode_time = 0
vb_time = 0
mfcc_time = 0
predict_time = 0
"""
错误码
"""
ERR_CODE_SUCCESS = 0 # 处理成功
ERR_CODE_NO_FILE = -1 # 文件不存在
ERR_CODE_TRANSCODE = -2 # 转码失败
ERR_CODE_VOLUME_BALANCED = -3 # 均衡化失败
ERR_CODE_FEATURE_TOO_SHORT = -4 # 特征文件太短
"""
常量
"""
FRAME_LEN = 128
MFCC_LEN = 80
EBUR128_BIN = "/data/gpu_env_common/res/av_svc/bin/standard_audio_no_cut"
# EBUR128_BIN = "/Users/yangjianli/linux/opt/soft/bin/standard_audio_no_cut"
GENDER_FEMALE = 0
GENDER_MALE = 1
GENDER_OTHER = 2
"""
通用函数
"""
def exec_cmd(cmd):
    """Run a shell command via ``os.system``; return True when it returns 0."""
    return os.system(cmd) == 0
"""
业务需要的函数
"""
def get_one_mfcc(file_url):
    """Load an audio file at 16 kHz and return its MFCC matrix, transposed.

    Returns an empty list when fewer than 512 samples were loaded. The time
    spent is printed and added to the module-level ``mfcc_time`` counter.
    """
    global mfcc_time
    started = time.time()
    samples, sample_rate = librosa.load(file_url, sr=16000)
    if len(samples) < 512:
        return []
    features = librosa.feature.mfcc(y=samples, sr=sample_rate, n_fft=512, hop_length=256,
                                    n_mfcc=MFCC_LEN).transpose()
    print("get_one_mfcc:spend_time={}".format(time.time() - started))
    mfcc_time += time.time() - started
    return features
def volume_balanced(src, dst):
    """Run the EBUR128 tool on ``src`` writing ``dst``; return True when ``dst`` exists.

    The time spent is printed and added to the module-level ``vb_time`` counter.
    """
    global vb_time
    started = time.time()
    command = "{} {} {}".format(EBUR128_BIN, src, dst)
    logging.info(command)
    exec_cmd(command)
    if not os.path.exists(dst):
        logging.error("volume_balanced:cmd={}".format(command))
    print("volume_balanced:spend_time={}".format(time.time() - started))
    vb_time += time.time() - started
    return os.path.exists(dst)
def transcode(src, dst):
    """Convert ``src`` to 16 kHz mono with ffmpeg; return True when ``dst`` exists.

    The time spent is printed and added to the module-level ``transcode_time``
    counter.
    """
    global transcode_time
    started = time.time()
    command = "ffmpeg -loglevel quiet -i {} -ar 16000 -ac 1 {}".format(src, dst)
    logging.info(command)
    exec_cmd(command)
    if not os.path.exists(dst):
        logging.error("transcode:cmd={}".format(command))
    print("transcode:spend_time={}".format(time.time() - started))
    transcode_time += time.time() - started
    return os.path.exists(dst)
class VoiceClass:
    """Gender classifier built from two voice-detection and two gender models."""

    def __init__(self, music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model):
        """
        Load the four models.

        :param music_voice_pure_model: tells pure (clean) voice from everything else
        :param music_voice_no_pure_model: tells "contains voice" from everything else
        :param gender_pure_model: male/female classifier for pure voice
        :param gender_no_pure_model: male/female classifier for audio that contains voice
        """
        st = time.time()
        self.device = "cpu"
        self.batch_size = 256
        self.music_voice_pure_model = load_model(MusicVoiceV5Model, music_voice_pure_model, self.device)
        self.music_voice_no_pure_model = load_model(MusicVoiceV5Model, music_voice_no_pure_model, self.device)
        self.gender_pure_model = load_model(MobileNetV2Gender, gender_pure_model, self.device)
        self.gender_no_pure_model = load_model(MobileNetV2Gender, gender_no_pure_model, self.device)
        logging.info("load model ok ! spend_time={}".format(time.time() - st))

    def batch_predict(self, model, features):
        """
        Run ``model`` over ``features`` in chunks of ``self.batch_size``.

        :return: numpy array with the softmax scores of every input row.
        """
        global predict_time
        st = time.time()
        scores = []
        with torch.no_grad():
            for i in range(0, len(features), self.batch_size):
                cur_data = features[i:i + self.batch_size].to(self.device)
                predicts = model(cur_data)
                predicts_score = F.softmax(predicts, dim=1)
                scores.extend(predicts_score.cpu().numpy())
        ret = np.array(scores)
        predict_time += time.time() - st
        return ret

    def _predict_gender(self, tag, filename, features, voice_model, gender_model, female_thr, male_thr):
        """
        Shared implementation of predict_pure / predict_no_pure.

        Keeps the frames that ``voice_model`` does not score as non-voice,
        then averages the ``gender_model`` scores over the kept frames.

        :param tag: name used in log messages ("predict_pure" / "predict_no_pure").
        :param female_thr: female_rate strictly above this means female.
        :param male_thr: female_rate strictly below this means male.
        :return: ``(gender, female_rate)``; female_rate is -1 when too few
            voice frames were found.
        """
        scores = self.batch_predict(voice_model, features)
        new_features = []
        for idx, score in enumerate(scores):
            if score[0] > 0.5:  # non-voice frame
                continue
            new_features.append(features[idx].numpy())

        # Too few voice frames to decide; both limits are tunable.
        new_feature_len = len(new_features)
        new_feature_rate = new_feature_len / len(features) if len(features) else 0.0
        if new_feature_len < 4 or new_feature_rate < 0.4:
            logging.warning(
                "filename={}|{}|other|len={}|rate={}".format(filename, tag, new_feature_len, new_feature_rate)
            )
            return GENDER_OTHER, -1

        new_features = torch.from_numpy(np.array(new_features))
        scores = self.batch_predict(gender_model, new_features)
        f_avg = sum(scores[:, 0]) / len(scores)
        m_avg = sum(scores[:, 1]) / len(scores)
        female_rate = f_avg / (f_avg + m_avg)
        if female_rate > female_thr:
            return GENDER_FEMALE, female_rate
        if female_rate < male_thr:
            return GENDER_MALE, female_rate
        logging.warning(
            "filename={}|{}|other|len={}|rate={}".format(filename, tag, new_feature_len, new_feature_rate)
        )
        return GENDER_OTHER, female_rate

    def predict_pure(self, filename, features):
        """Classify with the pure-voice models; returns ``(gender, female_rate)``."""
        return self._predict_gender("predict_pure", filename, features, self.music_voice_pure_model,
                                    self.gender_pure_model, 0.65, 0.12)

    def predict_no_pure(self, filename, features):
        """Classify with the contains-voice models; returns ``(gender, female_rate)``."""
        return self._predict_gender("predict_no_pure", filename, features, self.music_voice_no_pure_model,
                                    self.gender_no_pure_model, 0.75, 0.1)

    def predict(self, filename, features):
        """
        Split the feature rows into windows of FRAME_LEN rows and classify them.

        :return: ``(gender, female_rate, is_pure)``; ``is_pure`` is True only
            when the pure-voice models produced the decision.
        """
        st = time.time()
        new_features = []
        for i in range(FRAME_LEN, len(features), FRAME_LEN):
            new_features.append(features[i - FRAME_LEN: i])
        if not new_features:
            # Not a single window could be cut: nothing to classify. Without
            # this guard the frame-rate computation divides by zero.
            logging.warning("filename={}|predict|other|no frames".format(filename))
            return GENDER_OTHER, -1, False
        new_features = torch.from_numpy(np.array(new_features))
        gender, rate = self.predict_pure(filename, new_features)
        if gender == GENDER_OTHER:
            logging.info("start no pure process...")
            gender, rate = self.predict_no_pure(filename, new_features)
            return gender, rate, False
        print("predict|spend_time={}".format(time.time() - st))
        return gender, rate, True

    def process_one_logic(self, filename, file_path, cache_dir):
        """
        Transcode, volume-balance, extract features and classify one file.

        :return: ``(gender, female_rate, is_pure)``, or ``(ERR_CODE_*, None, None)``
            when a preparation step fails.
        """
        tmp_wav = os.path.join(cache_dir, "tmp.wav")
        tmp_vb_wav = os.path.join(cache_dir, "tmp_vb.wav")
        if not transcode(file_path, tmp_wav):
            return ERR_CODE_TRANSCODE, None, None
        if not volume_balanced(tmp_wav, tmp_vb_wav):
            return ERR_CODE_VOLUME_BALANCED, None, None
        features = get_one_mfcc(tmp_vb_wav)
        if len(features) < FRAME_LEN:
            logging.error("feature too short|file_path={}".format(file_path))
            return ERR_CODE_FEATURE_TOO_SHORT, None, None
        return self.predict(filename, features)

    def process_one(self, file_path):
        """Classify one file using a temporary "<file>_cache" working directory."""
        base_dir = os.path.dirname(file_path)
        filename = os.path.splitext(file_path)[0]
        cache_dir = os.path.join(base_dir, filename + "_cache")
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
        os.makedirs(cache_dir)
        try:
            return self.process_one_logic(filename, file_path, cache_dir)
        finally:
            # Remove the working directory even when classification raises.
            shutil.rmtree(cache_dir)

    def process(self, file_path):
        """
        Classify one audio file.

        :return: ``(gender, female_rate, is_pure)``; on a preparation failure
            ``gender`` is a negative ERR_CODE_* value and the rest is None.
        """
        gender, female_rate, is_pure = self.process_one(file_path)
        logging.info("{}|gender={}|female_rate={}".format(file_path, gender, female_rate))
        return gender, female_rate, is_pure

    def process_by_feature(self, feature_file):
        """
        Classify directly from a saved feature file.

        :param feature_file: path of a ``.npy`` file with the MFCC features.
        :return: ``(gender, female_rate)``
        """
        filename = os.path.splitext(feature_file)[0]
        features = np.load(feature_file)
        # predict() returns three values; this method keeps its two-value contract.
        gender, female_rate, _ = self.predict(filename, features)
        return gender, female_rate
def test_all_feature():
    """Evaluate the classifier on the saved feature files and print a confusion summary."""
    import glob
    global transcode_time, vb_time, mfcc_time, predict_time

    base_dir = "/data/datasets/music_voice_dataset_full/feature_online_data_v3"
    # (expected classifier code, label used in messages, files of that class)
    groups = [
        (0, "female", glob.glob(os.path.join(base_dir, "female/*feature.npy"))),
        (1, "male", glob.glob(os.path.join(base_dir, "male/*feature.npy"))),
        (2, "other", glob.glob(os.path.join(base_dir, "other/*feature.npy"))),
    ]

    model_path = "/data/jianli.yang/voice_classification/online/models"
    music_voice_pure_model = os.path.join(model_path, "voice_005_rec_v5.pth")
    music_voice_no_pure_model = os.path.join(model_path, "voice_10_v5.pth")
    gender_pure_model = os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth")
    gender_no_pure_model = os.path.join(model_path, "gender_8k_v6_adam.pth")
    vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model)

    tot_st = time.time()
    # ret_map[expected][predicted] -> count
    ret_map = {expected: {0: 0, 1: 0, 2: 0} for expected in (0, 1, 2)}
    for expected, label, files in groups:
        for file in files:
            st = time.time()
            print("------------------------------>>>>>")
            gender, female_score = vc.process_by_feature(file)
            ret_map[expected][gender] += 1
            if gender != expected:
                print("err:{}->{}|{}|{}".format(label, gender, file, female_score))
            print("process|spend_tm=={}".format(time.time() - st))

    print("spend_time:tot={}|transcode={}|vb={}|gen_feature={}|predict={}".format(time.time() - tot_st, transcode_time,
                                                                                  vb_time, mfcc_time, predict_time))
    f_f, f_m, f_o = ret_map[0][0], ret_map[0][1], ret_map[0][2]
    m_f, m_m, m_o = ret_map[1][0], ret_map[1][1], ret_map[1][2]
    o_f, o_m, o_o = ret_map[2][0], ret_map[2][1], ret_map[2][2]
    print("ff:{},fm:{},fo:{}".format(f_f, f_m, f_o))
    print("mm:{},mf:{},mo:{}".format(m_m, m_f, m_o))
    print("om:{},of:{},oo:{}".format(o_m, o_f, o_o))

    # Female accuracy and recall.
    f_acc = f_f / (f_f + m_f + o_f)
    f_recall = f_f / (f_f + f_m + f_o)
    # Male accuracy and recall.
    m_acc = m_m / (m_m + f_m + o_m)
    m_recall = m_m / (m_m + m_f + m_o)
    print("female: acc={}|recall={}".format(f_acc, f_recall))
    print("male: acc={}|recall={}".format(m_acc, m_recall))
def test_all():
    """Evaluate the classifier on the labelled mp4 files and print a confusion summary."""
    import glob
    global transcode_time, vb_time, mfcc_time, predict_time

    def _ratio(numerator, denominator):
        # An empty bucket must not abort the whole report.
        return numerator / denominator if denominator else 0.0

    base_dir = "/data/datasets/music_voice_dataset_full/online_data_v3_top200"
    # (expected classifier code, label used in messages, files of that class)
    groups = [
        (0, "female", glob.glob(os.path.join(base_dir, "female/*mp4"))),
        (1, "male", glob.glob(os.path.join(base_dir, "male/*mp4"))),
        (2, "other", glob.glob(os.path.join(base_dir, "other/*mp4"))),
    ]

    model_path = "/data/jianli.yang/voice_classification/online/models"
    music_voice_pure_model = os.path.join(model_path, "voice_005_rec_v5.pth")
    music_voice_no_pure_model = os.path.join(model_path, "voice_10_v5.pth")
    gender_pure_model = os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth")
    gender_no_pure_model = os.path.join(model_path, "gender_8k_v6_adam.pth")
    vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model)

    tot_st = time.time()
    # ret_map[expected][predicted] -> count
    ret_map = {expected: {0: 0, 1: 0, 2: 0} for expected in (0, 1, 2)}
    for expected, label, files in groups:
        for file in files:
            st = time.time()
            print("------------------------------>>>>>")
            # process() returns (gender, female_rate, is_pure).
            gender, female_score, _ = vc.process(file)
            if gender not in ret_map[expected]:
                # A negative error code (transcode / balance / too short):
                # report it and leave the file out of the statistics.
                print("err:{}->failed code={}|{}".format(label, gender, file))
                print("process|spend_tm=={}".format(time.time() - st))
                continue
            ret_map[expected][gender] += 1
            if gender != expected:
                print("err:{}->{}|{}|{}".format(label, gender, file, female_score))
            print("process|spend_tm=={}".format(time.time() - st))

    print("spend_time:tot={}|transcode={}|vb={}|gen_feature={}|predict={}".format(time.time() - tot_st, transcode_time,
                                                                                  vb_time, mfcc_time, predict_time))
    f_f, f_m, f_o = ret_map[0][0], ret_map[0][1], ret_map[0][2]
    m_f, m_m, m_o = ret_map[1][0], ret_map[1][1], ret_map[1][2]
    o_f, o_m, o_o = ret_map[2][0], ret_map[2][1], ret_map[2][2]
    print("ff:{},fm:{},fo:{}".format(f_f, f_m, f_o))
    print("mm:{},mf:{},mo:{}".format(m_m, m_f, m_o))
    print("om:{},of:{},oo:{}".format(o_m, o_f, o_o))

    # Female accuracy and recall.
    f_acc = _ratio(f_f, f_f + m_f + o_f)
    f_recall = _ratio(f_f, f_f + f_m + f_o)
    # Male accuracy and recall.
    m_acc = _ratio(m_m, m_m + f_m + o_m)
    m_recall = _ratio(m_m, m_m + m_f + m_o)
    print("female: acc={}|recall={}".format(f_acc, f_recall))
    print("male: acc={}|recall={}".format(m_acc, m_recall))
if __name__ == "__main__":
    # test_all()
    # test_all_feature()
    # Usage: <script> <model_dir> <audio_file>
    model_path = sys.argv[1]
    voice_path = sys.argv[2]
    vc = VoiceClass(
        os.path.join(model_path, "voice_005_rec_v5.pth"),
        os.path.join(model_path, "voice_10_v5.pth"),
        os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth"),
        os.path.join(model_path, "gender_8k_v6_adam.pth"),
    )
    for i in range(0, 1):
        st = time.time()
        print("------------------------------>>>>>")
        vc.process(voice_path)
        print("process|spend_tm=={}".format(time.time() - st))
diff --git a/AutoCoverTool/ref/so_vits_svc/inference/infer_tool.py b/AutoCoverTool/ref/so_vits_svc/inference/infer_tool.py
index 06a4676..1fad561 100644
--- a/AutoCoverTool/ref/so_vits_svc/inference/infer_tool.py
+++ b/AutoCoverTool/ref/so_vits_svc/inference/infer_tool.py
@@ -1,433 +1,433 @@
import hashlib
import json
import logging
import os
import time
from pathlib import Path
import librosa
import maad
import numpy as np
# import onnxruntime
import parselmouth
import soundfile
import torch
import torchaudio
from hubert import hubert_model
import utils
from models import SynthesizerTrn
import copy
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from mel_processing import spectrogram_torch, spec_to_mel_torch
def get_spec(audio):
    """Print ``audio`` and return its spectrogram from ``spectrogram_torch``.

    The spectrogram call uses the fixed positional settings 1280, 32000,
    320, 1280 with ``center=False``.
    """
    print(audio)
    return spectrogram_torch(audio, 1280, 32000, 320, 1280, center=False)
def read_temp(file_name, max_size=50 * 1024 * 1024, max_age=14 * 24 * 3600):
    """Load the JSON cache stored in ``file_name``.

    If the file does not exist it is created with a placeholder record and
    an empty dict is returned. If the file is larger than ``max_size``
    bytes, entries whose ``"time"`` stamp is older than ``max_age`` seconds
    are dropped from the returned dict (the file itself is not rewritten
    here). Any failure while reading/parsing/pruning is printed and the
    placeholder dict is returned instead.

    :param file_name: path of the JSON cache file.
    :param max_size: file size in bytes above which stale entries are pruned.
    :param max_age: maximum age in seconds of an entry kept while pruning.
    :return: the cache dict.
    """
    if not os.path.exists(file_name):
        with open(file_name, "w") as f:
            f.write(json.dumps({"info": "temp_dict"}))
        return {}

    try:
        with open(file_name, "r") as f:
            data_dict = json.loads(f.read())
        if os.path.getsize(file_name) > max_size:
            f_name = file_name.replace("\\", "/").split("/")[-1]
            print(f"clean {f_name}")
            now = int(time.time())
            for wav_hash in list(data_dict.keys()):
                entry = data_dict[wav_hash]
                # Records without a timestamp (e.g. the "info" placeholder,
                # whose value is a plain string) are not cache entries;
                # indexing them used to raise and wipe the whole cache.
                if not isinstance(entry, dict) or "time" not in entry:
                    continue
                if now - int(entry["time"]) > max_age:
                    del data_dict[wav_hash]
    except Exception as e:
        # Deliberate best effort: a broken cache is replaced, not fatal.
        print(e)
        print(f"{file_name} error,auto rebuild file")
        data_dict = {"info": "temp_dict"}
    return data_dict
def write_temp(file_name, data):
    """Overwrite ``file_name`` with the JSON serialization of ``data``."""
    with open(file_name, "w") as handle:
        serialized = json.dumps(data)
        handle.write(serialized)
def timeit(func):
    """Decorator that prints the wall-clock duration of every call to ``func``.

    The wrapped call's return value is passed through unchanged.
    """

    def run(*args, **kwargs):
        started = time.time()
        result = func(*args, **kwargs)
        elapsed = time.time() - started
        print('executing \'%s\' costed %.3fs' % (func.__name__, elapsed))
        return result

    return run
def format_wav(audio_path):
    """Write a ``.wav`` copy of ``audio_path`` next to it.

    Does nothing when the path already has the ``.wav`` suffix. Otherwise the
    file is loaded as mono at its native sample rate and written with the
    suffix replaced by ``.wav``.
    """
    src = Path(audio_path)
    if src.suffix == '.wav':
        return
    samples, rate = librosa.load(audio_path, mono=True, sr=None)
    soundfile.write(src.with_suffix(".wav"), samples, rate)
def get_end_file(dir_path, end):
    """Recursively collect files under ``dir_path`` whose name ends with ``end``.

    Files and directories whose name starts with a dot are ignored (hidden
    directories are not descended into). Returned paths use forward slashes.
    """
    matches = []
    for root, dirs, files in os.walk(dir_path):
        # In-place assignment so os.walk skips the hidden directories.
        dirs[:] = [d for d in dirs if not d.startswith('.')]
        matches.extend(
            os.path.join(root, name).replace("\\", "/")
            for name in files
            if not name.startswith('.') and name.endswith(end)
        )
    return matches
def get_md5(content):
    """Return the hexadecimal MD5 digest of the bytes-like ``content``."""
    digest = hashlib.md5(content)
    return digest.hexdigest()
def resize2d_f0(x, target_len):
    """Linearly resample the sequence ``x`` to ``target_len`` points.

    Values below 0.001 are treated as missing (NaN) during interpolation and
    any NaN left in the result is converted back to 0.
    """
    values = np.array(x)  # copy, so the caller's data is never modified
    values[values < 0.001] = np.nan
    n = len(values)
    sample_at = np.arange(0, n * target_len, n) / target_len
    resampled = np.interp(sample_at, np.arange(0, n), values)
    return np.nan_to_num(resampled)
def get_f0(x, p_len, f0_up_key=0):
    """Extract an F0 contour from 16 kHz audio and quantize it.

    The pitch is tracked with parselmouth's autocorrelation method using a
    10 ms time step and a 50-1100 Hz search range. The contour is cropped
    or zero-padded (roughly symmetrically) to ``p_len`` frames and then
    shifted by ``f0_up_key`` semitones.

    :param x: audio samples, passed to ``parselmouth.Sound`` at 16000 Hz.
    :param p_len: number of frames the returned contours must have.
    :param f0_up_key: pitch shift in semitones (``2 ** (key / 12)`` factor).
    :return: ``(f0_coarse, f0)`` - integer mel-scale bins in [1, 255] and the
        shifted F0 values in Hz.
    """
    time_step = 160 / 16000 * 1000  # frame step in milliseconds
    f0_min = 50
    f0_max = 1100
    f0_mel_min = 1127 * np.log(1 + f0_min / 700)
    f0_mel_max = 1127 * np.log(1 + f0_max / 700)

    f0 = parselmouth.Sound(x, 16000).to_pitch_ac(
        time_step=time_step / 1000, voicing_threshold=0.6,
        pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
    if len(f0) > p_len:
        f0 = f0[:p_len]
    pad_size = (p_len - len(f0) + 1) // 2
    if pad_size > 0 or p_len - len(f0) - pad_size > 0:
        f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode='constant')

    f0 *= pow(2, f0_up_key / 12)
    f0_mel = 1127 * np.log(1 + f0 / 700)
    voiced = f0_mel > 0
    f0_mel[voiced] = (f0_mel[voiced] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
    f0_mel[f0_mel <= 1] = 1
    f0_mel[f0_mel > 255] = 255
    # ``np.int`` was only an alias of the builtin ``int`` and no longer
    # exists in current NumPy releases; the builtin gives the same dtype.
    f0_coarse = np.rint(f0_mel).astype(int)
    return f0_coarse, f0
def clean_pitch(input_pitch):
    """Flatten ``input_pitch`` to all ones when over 90% of it already is 1.

    The array is modified in place and also returned.
    """
    is_one = input_pitch == 1
    if np.sum(is_one) / len(input_pitch) > 0.9:
        input_pitch[~is_one] = 1
    return input_pitch
def plt_pitch(input_pitch):
    """Return a float copy of ``input_pitch`` with every value of 1 set to NaN."""
    plottable = input_pitch.astype(float)
    plottable[plottable == 1] = np.nan
    return plottable
def f0_to_pitch(ff):
    """Map frequency ``ff`` to ``69 + 12 * log2(ff / 440)`` (440 maps to 69)."""
    semitones_from_440 = 12 * np.log2(ff / 440)
    return semitones_from_440 + 69
def fill_a_to_b(a, b):
    """Pad list ``a`` in place with copies of ``a[0]`` until it is as long as ``b``.

    Nothing happens when ``a`` is already at least as long as ``b``.
    """
    shortfall = len(b) - len(a)
    if shortfall > 0:
        a.extend([a[0]] * shortfall)
def mkdir(paths: list):
    """Create each directory in ``paths`` that does not exist yet (single level)."""
    for target in paths:
        if os.path.exists(target):
            continue
        os.mkdir(target)
class Svc(object):
    """File-based so-vits-svc inference wrapper.

    Loads the hyper-parameters, the HuBERT soft content encoder and the
    ``SynthesizerTrn`` generator once, then converts audio files with
    :meth:`infer`.
    """

    def __init__(self, net_g_path, config_path, hubert_path="/data/prod/so_vits_models/models/hubert-soft-0d54a1f4.pt",
                 onnx=False):
        """
        :param net_g_path: generator checkpoint; a path containing ``"half"``
            selects fp16 inference when CUDA is available.
        :param config_path: JSON hyper-parameter file.
        :param hubert_path: HuBERT soft model checkpoint.
        :param onnx: ONNX inference is not implemented; ``True`` makes
            :meth:`load_model` raise ``NotImplementedError``.
        """
        self.onnx = onnx
        self.net_g_path = net_g_path
        self.hubert_path = hubert_path
        self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net_g_ms = None
        self.hps_ms = utils.get_hparams_from_file(config_path)
        self.target_sample = self.hps_ms.data.sampling_rate
        self.hop_size = self.hps_ms.data.hop_length
        # Reverse lookup: speaker id -> speaker name.
        self.speakers = {}
        for spk, sid in self.hps_ms.spk.items():
            self.speakers[sid] = spk
        self.spk2id = self.hps_ms.spk
        # Load the HuBERT content encoder.
        self.hubert_soft = hubert_model.hubert_soft(hubert_path)
        if torch.cuda.is_available():
            self.hubert_soft = self.hubert_soft.cuda()
        self.load_model()

    def load_model(self):
        """Build the generator from the config and load its checkpoint."""
        if self.onnx:
            raise NotImplementedError
        self.net_g_ms = SynthesizerTrn(
            self.hps_ms.data.filter_length // 2 + 1,
            self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
            **self.hps_ms.model, no_flow=False, use_v3=False)
        _ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None)
        if "half" in self.net_g_path and torch.cuda.is_available():
            _ = self.net_g_ms.half().eval().to(self.dev)
        else:
            _ = self.net_g_ms.eval().to(self.dev)

    def get_units(self, source, sr):
        """Run HuBERT on ``source`` and return its content units.

        ``sr`` is accepted for interface compatibility and not used here.
        """
        source = source.unsqueeze(0).to(self.dev)
        with torch.inference_mode():
            start = time.time()
            units = self.hubert_soft.units(source)
            use_time = time.time() - start
            print("hubert use time:{}".format(use_time))
            return units

    def get_unit_pitch(self, in_path, tran):
        """Load ``in_path`` and extract HuBERT units and F0.

        :param in_path: audio file readable by ``torchaudio.load``.
        :param tran: pitch shift in semitones forwarded to ``get_f0``.
        :return: ``(soft, f0, source_bak)`` - units as a numpy array, the F0
            contour, and a copy of the audio as loaded (before resampling).
        """
        source, sr = torchaudio.load(in_path)
        source_bak = copy.deepcopy(source)
        source = torchaudio.functional.resample(source, sr, 16000)
        # NOTE(review): this tests shape[1] (second axis) before averaging
        # over axis 0 - confirm which axis is meant to be channels.
        if len(source.shape) == 2 and source.shape[1] >= 2:
            source = torch.mean(source, dim=0).unsqueeze(0)
        soft = self.get_units(source, sr).squeeze(0).cpu().numpy()
        # F0 is requested at twice the number of HuBERT frames.
        f0_coarse, f0 = get_f0(source.cpu().numpy()[0], soft.shape[0] * 2, tran)
        return soft, f0, source_bak

    def infer(self, speaker_id, tran, raw_path, dev=False):
        """Convert the audio file ``raw_path`` to the target speaker.

        :param speaker_id: speaker name (looked up in the config) or numeric id.
        :param tran: pitch shift in semitones.
        :param raw_path: input audio file.
        :param dev: unused.
        :return: ``(audio, n_samples)`` - float tensor and its length.
        """
        if isinstance(speaker_id, str):
            speaker_id = self.spk2id[speaker_id]
        sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
        soft, pitch, source = self.get_unit_pitch(raw_path, tran)
        f0 = torch.FloatTensor(clean_pitch(pitch)).unsqueeze(0).to(self.dev)
        if "half" in self.net_g_path and torch.cuda.is_available():
            stn_tst = torch.HalfTensor(soft)
        else:
            stn_tst = torch.FloatTensor(soft)

        # Magnitude-spectrogram conditioning (disabled):
        # spec = get_spec(source).to(self.dev)
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0).to(self.dev)
            start = time.time()
            # Repeat every unit frame twice so it lines up with the F0 frames.
            x_tst = torch.repeat_interleave(x_tst, repeats=2, dim=1).transpose(1, 2)
            audio = self.net_g_ms.infer(x_tst, f0=f0, g=sid)[0, 0].data.float()
            # audio = self.net_g_ms.infer_v1(x_tst, spec[:, :, :f0.size(-1)], f0=f0, g=sid)[0, 0].data.float()
            use_time = time.time() - start
            print("vits use time:{}".format(use_time))
        return audio, audio.shape[-1]
class SVCRealTimeByBuffer(object):
    """Buffer-based so-vits-svc inference.

    Unlike a file-based wrapper, ``infer`` takes an in-memory waveform, and
    ``process`` converts a whole file chunk by chunk, cross-fading the
    overlapping region between consecutive chunks.
    """

    def __init__(self, net_g_path, config_path, hubert_path="/data/prod/so_vits_models/models/hubert-soft-0d54a1f4.pt"):
        """Load hyper-parameters, the HuBERT encoder and the generator weights."""
        self.net_g_path = net_g_path
        self.hubert_path = hubert_path
        self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net_g_ms = None
        self.hps_ms = utils.get_hparams_from_file(config_path)
        self.target_sample = self.hps_ms.data.sampling_rate
        self.hop_size = self.hps_ms.data.hop_length
        # Reverse lookup: speaker id -> speaker name.
        self.speakers = {}
        for spk, sid in self.hps_ms.spk.items():
            self.speakers[sid] = spk
        self.spk2id = self.hps_ms.spk
        # Load the HuBERT content encoder.
        self.hubert_soft = hubert_model.hubert_soft(hubert_path)
        if torch.cuda.is_available():
            self.hubert_soft = self.hubert_soft.cuda()
        self.load_model()

    def load_model(self):
        """Build the generator (``no_flow=True``) and load a bare state dict."""
        self.net_g_ms = SynthesizerTrn(
            self.hps_ms.data.filter_length // 2 + 1,
            self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
            **self.hps_ms.model, no_flow=True)
        # _ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None)
        # The checkpoint file is passed to load_state_dict as-is, i.e. it must
        # hold the state dict itself rather than a wrapping dict.
        net_g = torch.load(self.net_g_path, map_location='cpu')
        self.net_g_ms.load_state_dict(net_g)
        # A path containing "half" selects fp16 when CUDA is available.
        if "half" in self.net_g_path and torch.cuda.is_available():
            _ = self.net_g_ms.half().eval().to(self.dev)
        else:
            _ = self.net_g_ms.eval().to(self.dev)

    def get_units(self, source, sr):
        """Run HuBERT on ``source`` and return the units (``sr`` is unused)."""
        source = source.unsqueeze(0).to(self.dev)
        print("source_shape===>", source.shape)
        with torch.inference_mode():
            start = time.time()
            units = self.hubert_soft.units(source)
            use_time = time.time() - start
            print("hubert use time:{}".format(use_time))
            return units

    def get_unit_pitch(self, source, sr, tran):
        """Resample ``source`` to 16 kHz and return ``(units, f0)``.

        ``tran`` is the pitch shift in semitones forwarded to ``get_f0``.
        """
        source = torchaudio.functional.resample(source, sr, 16000)
        # NOTE(review): this tests shape[1] (second axis) before averaging
        # over axis 0 - confirm which axis is meant to be channels.
        if len(source.shape) == 2 and source.shape[1] >= 2:
            source = torch.mean(source, dim=0).unsqueeze(0)
        soft = self.get_units(source, sr).squeeze(0).cpu().numpy()
        # F0 is requested at twice the number of HuBERT frames.
        f0_coarse, f0 = get_f0(source.cpu().numpy()[0], soft.shape[0] * 2, tran)
        return soft, f0

    def infer(self, speaker_id, tran, source, sr):
        """Convert the waveform ``source`` (sample rate ``sr``).

        ``speaker_id`` may be a speaker name or a numeric id. Returns
        ``(audio, n_samples)``.
        """
        if type(speaker_id) == str:
            speaker_id = self.spk2id[speaker_id]
        sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
        soft, pitch = self.get_unit_pitch(source, sr, tran)
        f0 = torch.FloatTensor(clean_pitch(pitch)).unsqueeze(0).to(self.dev)
        if "half" in self.net_g_path and torch.cuda.is_available():
            stn_tst = torch.HalfTensor(soft)
        else:
            stn_tst = torch.FloatTensor(soft)
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0).to(self.dev)
            start = time.time()
            # Repeat every unit frame twice so it lines up with the F0 frames.
            x_tst = torch.repeat_interleave(x_tst, repeats=2, dim=1).transpose(1, 2)
            audio = self.net_g_ms.infer(x_tst, f0=f0, g=sid)[0, 0].data.float()
            use_time = time.time() - start
            print("vits use time:{}".format(use_time))
        return audio, audio.shape[-1]

    def process(self, vocal_path, dst_path, tran=0):
        """Convert ``vocal_path`` in overlapping chunks and write ``dst_path``.

        The file is loaded as 32 kHz mono, converted with speaker id 0 in
        windows of ``length`` samples that advance by ``length - hop_len``,
        and the first ``hop_len`` samples of each later window are
        cross-faded into the output already produced.
        """
        source, sr = librosa.load(vocal_path, sr=32000, mono=True)
        # Process the signal window by window (original note: "once per second").
        out_audio = []
        source = torch.tensor(source).to(self.dev)
        hop_len = 3840 * 4  # overlap: 15360 samples = 0.48 s at 32 kHz (the original note said 120ms)
        length = 640 * 1000  # window: 640000 samples = 20 s at 32 kHz
        for i in range(0, len(source), length - hop_len):
            cur_hop_len = hop_len
            input_data = source[i:i + length].unsqueeze(0)
            audio, _ = self.infer(0, tran, input_data, sr)
            # Stop once a window yields less audio than the overlap length.
            if len(audio) < hop_len:
                break
            if len(out_audio) > 0:
                # Linearly cross-fade the start of this window with the tail
                # of the previous output.
                # NOTE(review): indexing out_audio at i + j assumes the model
                # returns exactly as many samples as it was given - TODO confirm.
                for j in range(hop_len):
                    out_audio[i+j] = out_audio[i+j] * (1-(j / hop_len)) + audio[j] * (j / hop_len)
            else:
                # First window: nothing to fade with, keep it entirely.
                cur_hop_len = 0
            out_audio.extend(audio[cur_hop_len:])
        soundfile.write(dst_path, out_audio, sr, format="wav")
# class SvcONNXInferModel(object):
# def __init__(self, hubert_onnx, vits_onnx, config_path):
# self.config_path = config_path
# self.vits_onnx = vits_onnx
# self.hubert_onnx = hubert_onnx
# self.hubert_onnx_session = onnxruntime.InferenceSession(hubert_onnx, providers=['CUDAExecutionProvider', ])
# self.inspect_onnx(self.hubert_onnx_session)
# self.vits_onnx_session = onnxruntime.InferenceSession(vits_onnx, providers=['CUDAExecutionProvider', ])
# self.inspect_onnx(self.vits_onnx_session)
# self.hps_ms = utils.get_hparams_from_file(self.config_path)
# self.target_sample = self.hps_ms.data.sampling_rate
# self.feature_input = FeatureInput(self.hps_ms.data.sampling_rate, self.hps_ms.data.hop_length)
#
# @staticmethod
# def inspect_onnx(session):
# for i in session.get_inputs():
# print("name:{}\tshape:{}\tdtype:{}".format(i.name, i.shape, i.type))
# for i in session.get_outputs():
# print("name:{}\tshape:{}\tdtype:{}".format(i.name, i.shape, i.type))
#
# def infer(self, speaker_id, tran, raw_path):
# sid = np.array([int(speaker_id)], dtype=np.int64)
# soft, pitch = self.get_unit_pitch(raw_path, tran)
# pitch = np.expand_dims(pitch, axis=0).astype(np.int64)
# stn_tst = soft
# x_tst = np.expand_dims(stn_tst, axis=0)
# x_tst_lengths = np.array([stn_tst.shape[0]], dtype=np.int64)
# # 使用ONNX Runtime进行推理
# start = time.time()
# audio = self.vits_onnx_session.run(output_names=["audio"],
# input_feed={
# "hidden_unit": x_tst,
# "lengths": x_tst_lengths,
# "pitch": pitch,
# "sid": sid,
# })[0][0, 0]
# use_time = time.time() - start
# print("vits_onnx_session.run time:{}".format(use_time))
# audio = torch.from_numpy(audio)
# return audio, audio.shape[-1]
#
# def get_units(self, source, sr):
# source = torchaudio.functional.resample(source, sr, 16000)
# if len(source.shape) == 2 and source.shape[1] >= 2:
# source = torch.mean(source, dim=0).unsqueeze(0)
# source = source.unsqueeze(0)
# # 使用ONNX Runtime进行推理
# start = time.time()
# units = self.hubert_onnx_session.run(output_names=["embed"],
# input_feed={"source": source.numpy()})[0]
# use_time = time.time() - start
# print("hubert_onnx_session.run time:{}".format(use_time))
# return units
#
# def transcribe(self, source, sr, length, transform):
# feature_pit = self.feature_input.compute_f0(source, sr)
# feature_pit = feature_pit * 2 ** (transform / 12)
# feature_pit = resize2d_f0(feature_pit, length)
# coarse_pit = self.feature_input.coarse_f0(feature_pit)
# return coarse_pit
#
# def get_unit_pitch(self, in_path, tran):
# source, sr = torchaudio.load(in_path)
# soft = self.get_units(source, sr).squeeze(0)
# input_pitch = self.transcribe(source.numpy()[0], sr, soft.shape[0], tran)
# return soft, input_pitch
class RealTimeVC:
    """Chunk-by-chunk voice conversion with cross-fading between chunks.

    State from the previous call (its output and the tail that is prepended
    to the next input) is kept on the instance.
    """

    def __init__(self):
        self.last_chunk = None
        self.last_o = None
        self.chunk_len = 16000  # chunk length
        self.pre_len = 3840  # cross-fade length, a multiple of 640

    def process(self, svc_model, speaker_id, f_pitch_change, input_wav_path):
        """Convert one chunk; input and output are 1-D numpy waveforms.

        :param svc_model: object whose ``infer(speaker_id, tran, wav)``
            returns ``(audio_tensor, length)``.
        :param speaker_id: forwarded to ``svc_model.infer``.
        :param f_pitch_change: pitch shift forwarded to ``svc_model.infer``.
        :param input_wav_path: seekable file-like object holding the chunk
            (``seek(0)`` is called on it).
        :return: the converted chunk as a numpy array.
        """
        # Local import: the module's import block does not provide ``io``,
        # so ``io.BytesIO`` below used to raise NameError.
        import io

        audio, sr = torchaudio.load(input_wav_path)
        audio = audio.cpu().numpy()[0]
        temp_wav = io.BytesIO()
        if self.last_chunk is None:
            # First chunk: convert the input as-is.
            input_wav_path.seek(0)
            audio, _ = svc_model.infer(speaker_id, f_pitch_change, input_wav_path)
            audio = audio.cpu().numpy()
            self.last_chunk = audio[-self.pre_len:]
            self.last_o = audio
            return audio[-self.chunk_len:]

        # Later chunks: prepend the stored tail, convert, then cross-fade the
        # new output with the previous one.
        audio = np.concatenate([self.last_chunk, audio])
        soundfile.write(temp_wav, audio, sr, format="wav")
        temp_wav.seek(0)
        audio, _ = svc_model.infer(speaker_id, f_pitch_change, temp_wav)
        audio = audio.cpu().numpy()
        ret = maad.util.crossfade(self.last_o, audio, self.pre_len)
        self.last_chunk = audio[-self.pre_len:]
        self.last_o = audio
        return ret[self.chunk_len:2 * self.chunk_len]
diff --git a/AutoCoverTool/ref/so_vits_svc/utils.py b/AutoCoverTool/ref/so_vits_svc/utils.py
index 6bba348..f6057d1 100644
--- a/AutoCoverTool/ref/so_vits_svc/utils.py
+++ b/AutoCoverTool/ref/so_vits_svc/utils.py
@@ -1,366 +1,374 @@
import os
import glob
import re
import sys
import argparse
import logging
import json
import subprocess
import librosa
import numpy as np
import torchaudio
from scipy.io.wavfile import read
import torch
import torchvision
from torch.nn import functional as F
from commons import sequence_mask
from hubert import hubert_model
MATPLOTLIB_FLAG = False
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
# Quantization settings shared by f0_to_coarse.
f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)


def f0_to_coarse(f0):
    """Quantize F0 values (Hz) to integer bins in [1, 255] on a mel scale.

    Frequencies are mapped with ``1127 * ln(1 + f / 700)``; the range
    ``f0_min``..``f0_max`` is spread over bins 1..255 and everything outside
    is clamped. Zero (unvoiced) frames end up in bin 1.

    :param f0: ``torch.Tensor`` or numpy array of frequencies.
    :return: integer tensor/array of the same shape.
    """
    is_torch = isinstance(f0, torch.Tensor)
    f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
    voiced = f0_mel > 0
    f0_mel[voiced] = (f0_mel[voiced] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1

    f0_mel[f0_mel <= 1] = 1
    f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
    # Round to the nearest bin. ``np.int`` was only an alias of the builtin
    # ``int`` and no longer exists in current NumPy releases.
    f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int)
    assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
    return f0_coarse
def get_hubert_model(rank=None):
    """Load the HuBERT soft model from its fixed production path.

    When ``rank`` is given the model is moved to that CUDA device.
    """
    model = hubert_model.hubert_soft("/data/prod/so_vits_models/models/hubert-soft-0d54a1f4.pt")
    if rank is None:
        return model
    return model.cuda(rank)
def get_hubert_content(hmodel, y=None, path=None):
    """Return HuBERT units for a waveform, with the last two axes swapped.

    The waveform is either loaded from ``path`` (resampled to 16 kHz) or
    taken from ``y`` when ``path`` is ``None``.
    """
    if path is None:
        source = y
    else:
        source, sr = torchaudio.load(path)
        source = torchaudio.functional.resample(source, sr, 16000)
        # NOTE(review): tests shape[1] before averaging over axis 0 - confirm
        # which axis is meant to be channels.
        if len(source.shape) == 2 and source.shape[1] >= 2:
            source = torch.mean(source, dim=0).unsqueeze(0)
    batch = source.unsqueeze(0)
    with torch.inference_mode():
        units = hmodel.units(batch)
        return units.transpose(1, 2)
def get_content(cmodel, y):
    """Return the first output of ``cmodel.extract_features`` with axes 1 and 2 swapped.

    ``y`` has its axis 1 squeezed before being passed to the model; the call
    runs without gradient tracking.
    """
    with torch.no_grad():
        features = cmodel.extract_features(y.squeeze(1))[0]
        return features.transpose(1, 2)
def transform(mel, height):  # 68-92
    """Resize ``mel`` to ``height`` rows while keeping its original row count.

    A taller result is cropped back to the original number of rows. A
    shorter one is padded by repeating its last row with Gaussian noise
    (scaled by 1/10) added.
    """
    n_rows = mel.size(-2)
    resized = torchvision.transforms.functional.resize(mel, (height, mel.size(-1)))
    if height >= n_rows:
        return resized[:, :n_rows, :]
    filler = resized[:, -1:, :].repeat(1, n_rows - height, 1)
    filler += torch.randn_like(filler) / 10
    return torch.cat((resized, filler), 1)
def stretch(mel, width):  # 0.5-2
    """Resize ``mel`` along its last axis to ``width``, keeping the row count."""
    target_size = (mel.size(-2), width)
    return torchvision.transforms.functional.resize(mel, target_size)
def load_checkpoint(checkpoint_path, model, optimizer=None):
    """Load weights (and optionally optimizer state) from ``checkpoint_path``.

    Two file layouts are accepted:

    * a full training checkpoint - a dict with a ``'model'`` entry and
      optional ``'iteration'``, ``'learning_rate'`` and ``'optimizer'``;
    * a bare state dict (no ``'model'`` key), loaded directly into the model.

    Parameters missing from a full checkpoint keep the model's current
    values and are logged.

    :param checkpoint_path: existing checkpoint file.
    :param model: network to fill; an object exposing ``.module`` is unwrapped.
    :param optimizer: optional optimizer restored from the checkpoint.
    :return: ``(model, optimizer, learning_rate, iteration)``; for a bare
        state dict the optimizer is ``None`` and the defaults ``0.0002`` / ``1``
        are returned.
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    target = model.module if hasattr(model, 'module') else model

    # Weights-only file: nothing but the state dict is available.
    if 'model' not in checkpoint_dict:
        target.load_state_dict(checkpoint_dict)
        # Same order as the regular return below: learning rate, then
        # iteration (the two defaults used to be swapped here).
        return model, None, 0.0002, 1

    iteration = checkpoint_dict.get('iteration', None)
    learning_rate = checkpoint_dict.get('learning_rate', None)
    if iteration is None:
        iteration = 1
    if learning_rate is None:
        learning_rate = 0.0002
    if optimizer is not None and checkpoint_dict.get('optimizer', None) is not None:
        optimizer.load_state_dict(checkpoint_dict['optimizer'])

    saved_state_dict = checkpoint_dict['model']
    new_state_dict = {}
    for k, v in target.state_dict().items():
        if k in saved_state_dict:
            new_state_dict[k] = saved_state_dict[k]
        else:
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    target.load_state_dict(new_state_dict)
    logger.info("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write model weights, optimizer state, iteration and learning rate to ``checkpoint_path``.

    A model exposing ``.module`` is unwrapped before its state dict is taken.
    """
    logger.info("Saving model and optimizer state at iteration {} to {}".format(
        iteration, checkpoint_path))
    net = model.module if hasattr(model, 'module') else model
    payload = {
        'model': net.state_dict(),
        'iteration': iteration,
        'optimizer': optimizer.state_dict(),
        'learning_rate': learning_rate,
    }
    torch.save(payload, checkpoint_path)
    # Automatic pruning of old checkpoints is switched off.
    clean_ckpt = False
    if clean_ckpt:
        clean_checkpoints(path_to_models='logs/32k/', n_ckpts_to_keep=3, sort_by_time=True)
def clean_checkpoints(path_to_models='logs/48k/', n_ckpts_to_keep=2, sort_by_time=True):
    """Free up space by deleting saved checkpoints.

    Generator (``G*``) and discriminator (``D*``) files are handled
    separately; in each group the last ``n_ckpts_to_keep`` files in sort
    order are kept and the rest removed. Files ending in ``_0.pth`` are
    never touched.

    Arguments:
    path_to_models  --  Path to the model directory
    n_ckpts_to_keep --  Number of ckpts to keep, excluding G_0.pth and D_0.pth
    sort_by_time    --  True -> chronologically delete ckpts
                        False -> delete by the number in the file name
    """
    ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
    # Raw string: "\d" in a normal literal is an invalid escape sequence.
    index_pattern = re.compile(r'._(\d+)\.pth')

    def name_key(_f):
        return int(index_pattern.match(_f).group(1))

    def time_key(_f):
        return os.path.getmtime(os.path.join(path_to_models, _f))

    sort_key = time_key if sort_by_time else name_key

    def x_sorted(_x):
        return sorted((f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')),
                      key=sort_key)

    stale = x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep]
    for name in stale:
        fn = os.path.join(path_to_models, name)
        os.remove(fn)
        logger.info(f".. Free up space by deleting ckpt {fn}")
def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
    """Forward every entry of the given dicts to the summary ``writer``.

    Each dict maps a tag to a value; all entries are logged at
    ``global_step``. Images are logged with ``dataformats='HWC'`` and audio
    with ``audio_sampling_rate``. The default dicts are never modified.
    """
    for tag, value in scalars.items():
        writer.add_scalar(tag, value, global_step)
    for tag, hist in histograms.items():
        writer.add_histogram(tag, hist, global_step)
    for tag, image in images.items():
        writer.add_image(tag, image, global_step, dataformats='HWC')
    for tag, clip in audios.items():
        writer.add_audio(tag, clip, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return (and print) the matching checkpoint with the highest number.

    Files matching the glob ``regex`` in ``dir_path`` are ordered by the
    integer formed from all digits in their path; the last one wins.
    """

    def digits_as_int(path):
        return int("".join(ch for ch in path if ch.isdigit()))

    candidates = glob.glob(os.path.join(dir_path, regex))
    newest = sorted(candidates, key=digits_as_int)[-1]
    print(newest)
    return newest
def plot_spectrogram_to_numpy(spectrogram):
    """Render ``spectrogram`` with matplotlib and return the figure as an RGB array.

    :param spectrogram: 2-D data accepted by ``Axes.imshow``.
    :return: ``uint8`` array of shape ``(height, width, 3)``.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        # One-time setup: select the Agg backend and quiet matplotlib's
        # logger before pylab is imported.
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.frombuffer replaces the deprecated binary mode of np.fromstring;
    # the copy keeps the result writable, as it was before.
    # NOTE(review): tostring_rgb is reportedly deprecated in recent
    # Matplotlib releases - confirm against the pinned version.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    plt.close()
    return data
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix with matplotlib and return it as an RGB array.

    :param alignment: 2-D array; it is transposed before plotting.
    :param info: optional text appended below the x-axis label.
    :return: ``uint8`` array of shape ``(height, width, 3)``.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        # One-time setup: select the Agg backend and quiet matplotlib's
        # logger before pylab is imported.
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
                   interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()

    fig.canvas.draw()
    # np.frombuffer replaces the deprecated binary mode of np.fromstring;
    # the copy keeps the result writable, as it was before.
    # NOTE(review): tostring_rgb is reportedly deprecated in recent
    # Matplotlib releases - confirm against the pinned version.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    plt.close()
    return data
def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(samples as float32 tensor, sampling rate)``.

    Sample values are converted to float32 without rescaling.
    """
    rate, samples = read(full_path)
    as_float = samples.astype(np.float32)
    return torch.FloatTensor(as_float), rate
def load_filepaths_and_text(filename, split="|"):
    """Read a UTF-8 text file and split each stripped line on ``split``.

    Returns one list of fields per line.
    """
    rows = []
    with open(filename, encoding='utf-8') as f:
        for line in f:
            rows.append(line.strip().split(split))
    return rows
def get_hparams(init=True):
    """Build hyper-parameters from the command line.

    Parses ``-c/--config``, ``-m/--model`` and ``-l/--logs``; the model
    directory is ``<logs>/<model>`` and is created when missing. With
    ``init`` true the config file is copied to ``<model_dir>/config.json``;
    otherwise that saved copy is read. The returned ``HParams`` carries the
    extra attribute ``model_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
                        help='JSON file for configuration')
    parser.add_argument('-m', '--model', type=str, required=True,
                        help='Model name')
    parser.add_argument('-l', '--logs', type=str, required=True,
                        help='log Name')
    args = parser.parse_args()

    model_dir = os.path.join(args.logs, args.model)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        with open(args.config, "r") as src:
            data = src.read()
        with open(config_save_path, "w") as dst:
            dst.write(data)
    else:
        with open(config_save_path, "r") as src:
            data = src.read()

    hparams = HParams(**json.loads(data))
    hparams.model_dir = model_dir
    return hparams
def get_hparams_from_dir(model_dir):
    """Load ``<model_dir>/config.json`` into ``HParams`` and attach ``model_dir``."""
    with open(os.path.join(model_dir, "config.json"), "r") as f:
        config = json.load(f)
    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams
def get_hparams_from_file(config_path):
    """Load the JSON file ``config_path`` into an ``HParams`` object."""
    with open(config_path, "r") as f:
        config = json.load(f)
    return HParams(**config)
def check_git_hash(model_dir):
    """Compare the current git commit with the hash stored in ``model_dir``.

    If this file's directory has no ``.git`` entry a warning is logged and
    nothing else happens. Otherwise the hash from ``git rev-parse HEAD`` is
    compared with ``<model_dir>/githash``: a mismatch is logged as a
    warning, and a missing file is created with the current hash.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        # ``warning`` instead of the deprecated ``warn`` alias.
        logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
            source_dir
        ))
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        # Context managers so the file handle is closed deterministically.
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning("git hash values are different. {}(saved) != {}(current)".format(
                saved_hash[:8], cur_hash[:8]))
    else:
        with open(path, "w") as f:
            f.write(cur_hash)
def get_logger(model_dir, filename="train.log"):
    """Replace the module-level ``logger`` with a file logger for ``model_dir``.

    The logger is named after the last path component of ``model_dir``, set
    to DEBUG, and writes tab-separated records to ``<model_dir>/<filename>``.
    The directory is created when missing. Every call adds one more file
    handler to the named logger.
    """
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    file_handler = logging.FileHandler(os.path.join(model_dir, filename))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s"))
    logger.addHandler(file_handler)
    return logger
class HParams():
    """Attribute-style container for nested configuration values.

    Keyword arguments become attributes; plain ``dict`` values are wrapped
    recursively. Items can also be read and written with ``obj[key]``, and
    the usual mapping helpers operate on the instance attributes.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            # Exact type match on purpose: only plain dicts are wrapped.
            self[name] = HParams(**value) if type(value) is dict else value

    def keys(self):
        return vars(self).keys()

    def items(self):
        return vars(self).items()

    def values(self):
        return vars(self).values()

    def __len__(self):
        return len(vars(self))

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in vars(self)

    def __repr__(self):
        return repr(vars(self))
diff --git a/AutoCoverTool/script/train_user_by_one_media.py b/AutoCoverTool/script/train_user_by_one_media.py
index e01176d..e733f67 100644
--- a/AutoCoverTool/script/train_user_by_one_media.py
+++ b/AutoCoverTool/script/train_user_by_one_media.py
@@ -1,531 +1,547 @@
"""
使用一句话进行人声训练
1. 数据集
2. 训练
"""
from ref.so_vits_svc.models import SynthesizerTrn, MultiPeriodDiscriminator
from ref.so_vits_svc.mel_processing import spectrogram_torch, spec_to_mel_torch, mel_spectrogram_torch
import ref.so_vits_svc.utils as utils
import ref.so_vits_svc.commons as commons
from ref.so_vits_svc.losses import kl_loss, generator_loss, discriminator_loss, feature_loss
import logging
logging.getLogger('numba').setLevel(logging.WARNING)
import os
import time
import torch
import random
import librosa
import soundfile
import torchaudio
import parselmouth
import numpy as np
from tqdm import tqdm
from scipy.io.wavfile import read
from pyworld import pyworld
from copy import deepcopy
import torch.utils.data
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
+gs_denoise_exe = "/data/gpu_env_common/bin/denoise_exe"
+
gs_hmodel = utils.get_hubert_model(0 if torch.cuda.is_available() else None)
# Keyword arguments expanded into SynthesizerTrn(...) when the generator is
# built; "use_spectral_norm" is also read when building the discriminator.
gs_model_config = {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "upsample_rates": [10, 8, 2, 2],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [16, 16, 4, 4],
    "n_layers_q": 3,
    "use_spectral_norm": False,
    "gin_channels": 256,
    "ssl_dim": 256,
    "n_speakers": 2
}
# Training settings. "segment_size" (in samples) is divided by the hop length
# when the generator is constructed; "use_sr" and "max_speclen" (in
# spectrogram frames) are read by the dataset class.
gs_train_config = {
    "log_interval": 1,
    "eval_interval": 1000,
    "seed": 1234,
    "epochs": 1000,
    "learning_rate": 0.0001,
    "betas": [
        0.8,
        0.99
    ],
    "eps": 1e-09,
    "batch_size": 12,
    "fp16_run": False,
    "lr_decay": 0.999875,
    "segment_size": 17920,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 1.0,  # 45
    "c_kl": 1.0,
    "c_fm": 1.0,
    "c_gen": 1.0,
    "use_sr": True,
    "max_speclen": 384
}
# Audio / spectrogram settings read by the dataset class: at 32000 Hz a
# window of 1280 samples is 40 ms and a hop of 320 samples is 10 ms.
gs_data_config = {
    "max_wav_value": 32768.0,
    "sampling_rate": 32000,
    "filter_length": 1280,
    "hop_length": 320,
    "win_length": 1280,
    "n_mel_channels": 80,
    "mel_fmin": 0.0,
    "mel_fmax": None
}
def get_f0(x, p_len, f0_up_key=0):
    """Extract an F0 contour from 16 kHz audio and quantize it.

    The pitch is tracked with parselmouth's autocorrelation method using a
    10 ms time step and a 50-1100 Hz search range. The contour is cropped
    or zero-padded (roughly symmetrically) to ``p_len`` frames and then
    shifted by ``f0_up_key`` semitones.

    :param x: audio samples, passed to ``parselmouth.Sound`` at 16000 Hz.
    :param p_len: number of frames the returned contours must have.
    :param f0_up_key: pitch shift in semitones (``2 ** (key / 12)`` factor).
    :return: ``(f0_coarse, f0)`` - integer mel-scale bins in [1, 255] and the
        shifted F0 values in Hz.
    """
    time_step = 160 / 16000 * 1000  # frame step in milliseconds
    f0_min = 50
    f0_max = 1100
    f0_mel_min = 1127 * np.log(1 + f0_min / 700)
    f0_mel_max = 1127 * np.log(1 + f0_max / 700)

    f0 = parselmouth.Sound(x, 16000).to_pitch_ac(
        time_step=time_step / 1000, voicing_threshold=0.6,
        pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
    if len(f0) > p_len:
        f0 = f0[:p_len]
    pad_size = (p_len - len(f0) + 1) // 2
    if pad_size > 0 or p_len - len(f0) - pad_size > 0:
        f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode='constant')

    f0 *= pow(2, f0_up_key / 12)
    f0_mel = 1127 * np.log(1 + f0 / 700)
    voiced = f0_mel > 0
    f0_mel[voiced] = (f0_mel[voiced] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
    f0_mel[f0_mel <= 1] = 1
    f0_mel[f0_mel > 255] = 255
    # ``np.int`` was only an alias of the builtin ``int`` and no longer
    # exists in current NumPy releases; the builtin gives the same dtype.
    f0_coarse = np.rint(f0_mel).astype(int)
    return f0_coarse, f0
def resize2d(x, target_len):
    """Linearly resample the sequence ``x`` to ``target_len`` points.

    Values below 0.001 are treated as missing (NaN) during interpolation and
    any NaN left in the result is converted back to 0.
    """
    values = np.array(x)  # copy, so the caller's data is never modified
    values[values < 0.001] = np.nan
    n = len(values)
    sample_at = np.arange(0, n * target_len, n) / target_len
    resampled = np.interp(sample_at, np.arange(0, n), values)
    return np.nan_to_num(resampled)
def compute_f0(x, sr, c_len):
    """Estimate F0 with pyworld (DIO + StoneMask) and resample it to ``c_len`` frames.

    :param x: audio samples as a numpy array.
    :param sr: sampling rate of ``x``; the analysis frame period is 320 samples.
    :param c_len: number of frames wanted; must be within 3 of ``len(x) // 320``.
    :return: ``(None, f0)`` - the first slot is kept for call compatibility.
    """
    # x, sr = librosa.load(path, sr=32000)
    samples = x.astype(np.double)
    f0, t = pyworld.dio(
        samples,
        fs=sr,
        f0_ceil=800,
        frame_period=1000 * 320 / sr,
    )
    # Refine with the same sampling rate DIO used (was hard-coded to 32000,
    # which only matched when sr happened to be 32000).
    f0 = pyworld.stonemask(samples, f0, t, sr)
    f0 = np.round(f0, 1)
    assert abs(c_len - x.shape[0] // 320) < 3, (c_len, f0.shape)

    return None, resize2d(f0, c_len)
def process(filename):
    """Pre-compute and cache HuBERT units and F0 for one audio file.

    Units are stored in ``<filename>.soft.pt`` and F0 in
    ``<filename>.f0.npy``; each is only computed when its file is missing.
    """
    save_name = filename + ".soft.pt"
    if not os.path.exists(save_name):
        # Load the encoder only when the features are not cached yet.
        hmodel = utils.get_hubert_model(0 if torch.cuda.is_available() else None)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        wav, _ = librosa.load(filename, sr=16000)
        wav = torch.from_numpy(wav).unsqueeze(0).to(device)
        c = utils.get_hubert_content(hmodel, wav)
        torch.save(c.cpu(), save_name)
    else:
        c = torch.load(save_name)

    f0path = filename + ".f0.npy"
    if not os.path.exists(f0path):
        # compute_f0 takes samples and a sampling rate, not a path; the old
        # two-argument call raised TypeError. 32 kHz matches the 320-sample
        # frame period assumed by compute_f0.
        x, sr = librosa.load(filename, sr=32000)
        cf0, f0 = compute_f0(x, sr, c.shape[-1] * 2)
        np.save(f0path, f0)
def clean_pitch(input_pitch):
    """Flatten ``input_pitch`` to all ones when over 90% of it already is 1.

    The array is modified in place and also returned.
    """
    is_one = input_pitch == 1
    if np.sum(is_one) / len(input_pitch) > 0.9:
        input_pitch[~is_one] = 1
    return input_pitch
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
2) normalizes text and converts them to sequences of integers
3) computes spectrograms from audio files.
"""
def __init__(self, audio_path):
self.audio_path = audio_path
self.max_wav_value = gs_data_config['max_wav_value']
self.sampling_rate = gs_data_config['sampling_rate']
self.filter_length = gs_data_config['filter_length']
self.hop_length = gs_data_config['hop_length']
self.win_length = gs_data_config['win_length']
self.use_sr = gs_train_config['use_sr']
self.spec_len = gs_train_config['max_speclen']
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.hmodel = gs_hmodel
random.seed(1234)
self.audio_data = self.get_audio(audio_path)
def get_audio(self, filename):
# 原始音频32k单声道
# 这里存在疑惑:
# audio, sr = librosa.load(filename, sr=self.sampling_rate, mono=True)
sr, audio = read(filename)
audio = torch.FloatTensor(audio.astype(np.float32))
audio_norm = audio / self.max_wav_value
audio_norm = torch.tensor(audio_norm)
audio_norm = audio_norm.unsqueeze(0)
# 幅度谱 帧长1280(40ms),帧移320(10ms),shape为(641, frame_num)
spec = spectrogram_torch(audio_norm, self.filter_length,
self.sampling_rate, self.hop_length, self.win_length,
center=False)
# print(torch.mean(spec))
spec = torch.squeeze(spec, 0)
spk = torch.LongTensor([0])
# # 提取hubert特征,shape为(256, frame_num // 2),后面做补齐
wav = librosa.resample(audio.numpy(), sr, 16000)
wav = torch.from_numpy(wav).unsqueeze(0).to(self.device)
c = utils.get_hubert_content(self.hmodel, wav).squeeze(0)
# 提取f0特征,shape为(frame_num)
cf0, f0 = compute_f0(audio.numpy(), sr, c.shape[-1] * 2)
f0 = torch.FloatTensor(f0)
c = torch.repeat_interleave(c, repeats=2, dim=1) # shape=(256, frame_num)
lmin = min(c.size(-1), spec.size(-1), f0.shape[0])
# 当assert的前面的条件不成立的时候,会报错,并给出后面的信息
assert abs(c.size(-1) - spec.size(-1)) < 4, (c.size(-1), spec.size(-1), f0.shape, filename)
assert abs(lmin - spec.size(-1)) < 4, (c.size(-1), spec.size(-1), f0.shape)
assert abs(lmin - c.size(-1)) < 4, (c.size(-1), spec.size(-1), f0.shape)
spec, c, f0 = spec[:, :lmin], c[:, :lmin], f0[:lmin]
audio_norm = audio_norm[:, :lmin * self.hop_length]
_spec, _c, _audio_norm, _f0 = spec, c, audio_norm, f0
# 取幅度谱特征,hubert特征、f0信息
while spec.size(-1) < self.spec_len:
spec = torch.cat((spec, _spec), -1)
c = torch.cat((c, _c), -1)
f0 = torch.cat((f0, _f0), -1)
audio_norm = torch.cat((audio_norm, _audio_norm), -1)
# hubert特征,f0,幅度谱特征,对应音频段波形,人声编码
return c, f0, spec, audio_norm, spk
def random_one(self):
c, f0, spec, audio_norm, spk = self.audio_data
start = random.randint(0, spec.size(-1) - self.spec_len)
end = start + self.spec_len
spec = spec[:, start:end]
c = c[:, start:end]
f0 = f0[start:end]
audio_norm = audio_norm[:, start * self.hop_length:end * self.hop_length]
return c, f0, spec, audio_norm, spk
def __getitem__(self, index):
# Return one random training crop from random_one(); `index` is not used.
# On the added (+) side of this change, a crop whose mean absolute
# amplitude is below 0.02 is redrawn, with at most 3 redraws. The last
# draw is returned even if it is still below the threshold.
- return self.random_one()
+ c, f0, spec, audio_norm, spk = self.random_one()
+ # Skip segments that contain no vocals
+ cnt = 0
+ while torch.mean(torch.abs(audio_norm)) < 0.02 and cnt < 3:
+ c, f0, spec, audio_norm, spk = self.random_one()
+ cnt += 1
+ return c, f0, spec, audio_norm, spk
def __len__(self):
    """Report a fixed dataset length of one item."""
    item_count = 1
    return item_count
class SoVitsSVCOnlineTrain:
def construct_model(self):
    """Build the generator/discriminator pair and restore the base checkpoint.

    Both networks are created on CUDA and each gets a fresh AdamW optimizer.
    The weights and optimizer states cached on ``self`` by ``__init__`` are
    then loaded, and the learning rate of the first parameter group of each
    optimizer is reset to 2e-4.

    :return: ``(net_g, net_d, optim_g, optim_d)``
    """

    def _fresh_adamw(network):
        # Identical hyper-parameters are used for both networks.
        return torch.optim.AdamW(
            network.parameters(),
            0.0001,
            betas=[0.8, 0.99],
            eps=1e-09)

    freq_bins = gs_data_config["filter_length"] // 2 + 1
    segment_frames = gs_train_config["segment_size"] // gs_data_config["hop_length"]

    generator = SynthesizerTrn(
        freq_bins,
        segment_frames,
        **gs_model_config,
        no_flow=False,
        use_v3=False).cuda()
    discriminator = MultiPeriodDiscriminator(gs_model_config['use_spectral_norm']).cuda()

    gen_optimizer = _fresh_adamw(generator)
    disc_optimizer = _fresh_adamw(discriminator)

    # Restore the state cached in memory rather than re-reading the files.
    generator.load_state_dict(self.g_model_dict)
    discriminator.load_state_dict(self.d_model_dict)
    gen_optimizer.load_state_dict(self.g_opt_dict)
    disc_optimizer.load_state_dict(self.d_opt_dict)

    # Set the initial learning rate (overrides the value restored above).
    for optimizer in (gen_optimizer, disc_optimizer):
        optimizer.param_groups[0]['lr'] = 2e-4

    return generator, discriminator, gen_optimizer, disc_optimizer
def __init__(self, base_g_model, base_d_model):
    """Load the base generator and discriminator checkpoints into memory.

    Each checkpoint file is read onto the CPU and split into its ``"model"``
    and ``"optimizer"`` entries, which are kept on the instance for reuse.

    :param base_g_model: path of the generator checkpoint
    :param base_d_model: path of the discriminator checkpoint
    """

    def _split_checkpoint(path):
        # A checkpoint is a mapping holding at least "model" and "optimizer".
        loaded = torch.load(path, map_location='cpu')
        return loaded["model"], loaded["optimizer"]

    began = time.time()
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.g_model_dict, self.g_opt_dict = _split_checkpoint(base_g_model)
    self.d_model_dict, self.d_opt_dict = _split_checkpoint(base_d_model)
    print("load model_path={},{},sp={}".format(base_g_model, base_d_model, time.time() - began))
def get_units(self, source, sr):
    """Run the module-level HuBERT model on ``source`` and return its units.

    :param source: tensor to encode; a leading axis is added and the result
        is moved to ``self.device`` before encoding
    :param sr: accepted for interface compatibility; not used here
    :return: whatever ``gs_hmodel.units`` returns for the prepared tensor
    """
    batched = source.unsqueeze(0).to(self.device)
    print("source_shape===>", batched.shape)
    with torch.inference_mode():
        tick = time.time()
        encoded = gs_hmodel.units(batched)
        elapsed = time.time() - tick
    print("hubert use time:{}".format(elapsed))
    return encoded
def get_unit_pitch(self, source, sr, tran):
    """Extract HuBERT units and an f0 track from a waveform.

    The input is resampled from ``sr`` to 16000 Hz. If the result is 2-D with
    at least two columns, it is averaged over dim 0 and given a leading axis
    again (presumably a channel down-mix; confirm against callers).

    :param source: waveform tensor
    :param sr: sample rate of ``source``
    :param tran: forwarded unchanged to ``get_f0``
    :return: ``(soft, f0)``: the units as a NumPy array and the second
        value returned by ``get_f0``
    """
    wav_16k = torchaudio.functional.resample(source, sr, 16000)
    is_two_dim = len(wav_16k.shape) == 2
    if is_two_dim and wav_16k.shape[1] >= 2:
        wav_16k = torch.mean(wav_16k, dim=0).unsqueeze(0)

    units = self.get_units(wav_16k, sr)
    soft = units.squeeze(0).cpu().numpy()

    # f0 is requested at twice the number of unit frames.
    target_frames = soft.shape[0] * 2
    _, f0 = get_f0(wav_16k.cpu().numpy()[0], target_frames, tran)
    return soft, f0
def train(self, in_wav, epoch_num):
# Fine-tune the restored generator/discriminator pair on one audio file and
# return the generator.
#   in_wav:    passed straight to TextAudioSpeakerLoader as the training clip.
#   epoch_num: number of iterations of the outer epoch loop.
# Returns net_g (the generator built by construct_model) after training.
train_dataset = TextAudioSpeakerLoader(in_wav)
train_loader = DataLoader(train_dataset, num_workers=0, shuffle=False, batch_size=12)
net_g, net_d, optim_g, optim_d = self.construct_model()
rank = 0
# Used to speed up training
torch.set_float32_matmul_precision('high')
net_g.train()
net_d.train()
global_step = 0
scaler = GradScaler(enabled=gs_train_config['fp16_run'])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=gs_train_config['lr_decay'], last_epoch=1)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=gs_train_config['lr_decay'], last_epoch=1)
# Update the learning rate based on the previous result
# Idea: raise the learning rate when the loss falls, lower it when the loss rises
# NOTE(review): only the ExponentialLR schedulers above touch the learning
# rate in this method; the adaptive idea described here does not appear to
# be implemented - confirm.
for epoch in tqdm(range(0, epoch_num)):
for batch_idx, items in enumerate(train_loader):
# HuBERT features, f0, magnitude spectrogram, matching waveform segment (384 * hop_length), speaker code [0]
c, f0, spec, y, spk = items
g = spk.cuda(rank, non_blocking=True)
spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True)
c = c.cuda(rank, non_blocking=True)
f0 = f0.cuda(rank, non_blocking=True)
"""
"sampling_rate": 32000,
"filter_length": 1280,
"hop_length": 320,
"win_length": 1280,
"n_mel_channels": 80,
"mel_fmin": 0.0,
"mel_fmax": null
"""
# spec, n_fft, num_mels, sampling_rate, fmin, fmax
mel = spec_to_mel_torch(spec, gs_data_config['filter_length'], gs_data_config['n_mel_channels'],
gs_data_config['sampling_rate'], gs_data_config['mel_fmin'],
gs_data_config['mel_fmax'])
with autocast(enabled=gs_train_config['fp16_run']):
# net_g inputs: HuBERT features, f0, magnitude spectrogram, speaker id, mel spectrogram
# net_g outputs:
# raw waveform, the sampled frame positions within the batch, the valid frame positions of the magnitude spectrogram in the batch,
# z sampled from the normal distribution produced by encoding the magnitude spectrogram, z_p obtained by passing z through the normalizing flow, the mean of the normal distribution from the HuBERT feature layer (m_p),
# the standard deviation from the HuBERT feature layer (logs_p), the mean from the magnitude spectrogram and speaker information (m_q), the standard deviation from the magnitude spectrogram and speaker information (logs_q)
y_hat, ids_slice, z_mask, \
(z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, f0, spec, g=g, mel=mel)
y_mel = commons.slice_segments(mel, ids_slice,
gs_train_config['segment_size'] // gs_data_config['hop_length'])
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
gs_data_config['filter_length'],
gs_data_config['n_mel_channels'],
gs_data_config['sampling_rate'],
gs_data_config['hop_length'],
gs_data_config['win_length'],
gs_data_config['mel_fmin'],
gs_data_config['mel_fmax']
)
y = commons.slice_segments(y, ids_slice * gs_data_config['hop_length'],
gs_train_config['segment_size']) # slice
# Discriminator
# y_hat is detached so this pass does not back-propagate into net_g.
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
loss_disc_all = loss_disc
optim_d.zero_grad()
scaler.scale(loss_disc_all).backward()
scaler.unscale_(optim_d)
scaler.step(optim_d)
with autocast(enabled=gs_train_config['fp16_run']):
# Generator
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
with autocast(enabled=False):
# Loss between the mel spectrograms, scaled by the trailing coefficient; smaller is better
loss_mel = F.l1_loss(y_mel, y_hat_mel) * gs_train_config['c_mel']
# KL divergence. z_p: sample from the magnitude-spectrogram side after the normalizing flow; logs_q: standard deviation from the magnitude-spectrogram side; m_p: mean from the HuBERT side;
# logs_p: standard deviation from the HuBERT side; z_mask: valid frame positions of the magnitude spectrogram in the batch.
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * gs_train_config['c_kl']
# Take the per-layer feature maps of y and y_hat from the discriminator and compute their L1 distance
loss_fm = feature_loss(fmap_r, fmap_g) * gs_train_config['c_fm']
loss_gen, losses_gen = generator_loss(y_d_hat_g)
loss_gen_all = loss_gen * gs_train_config['c_gen'] + loss_fm + loss_mel + loss_kl
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
scaler.step(optim_g)
# A single scaler.update() follows both optimizer steps.
scaler.update()
if global_step % gs_train_config['log_interval'] == 0:
lr = optim_g.param_groups[0]['lr']
losses_numpy = [round(loss_disc.item(), 3), round(loss_gen.item(), 3),
round(loss_fm.item(), 3), round(loss_mel.item(), 3), round(loss_kl.item(), 3)]
print("gstep={},lr={},disc={},gen={},fm={},mel={},kl={},tot={}".format(global_step, lr,
losses_numpy[0],
losses_numpy[1],
losses_numpy[2],
losses_numpy[3],
losses_numpy[4],
sum(losses_numpy)))
# NOTE(review): hard-coded output path; a generator snapshot is written
# whenever global_step is a multiple of 200, which includes step 0.
if global_step % 200 == 0:
torch.save(net_g.state_dict(), "data/web_trained_models/xiafan_{}.pth".format(global_step))
global_step += 1
scheduler_g.step()
scheduler_d.step()
return net_g
def infer(self, in_wav, dst_wav, model):
    """Convert one vocal file with ``model`` and write the result as WAV.

    :param in_wav: path of the input audio; loaded by librosa as 32000 Hz mono
    :param dst_wav: path the converted audio is written to
    :param model: network exposing ``eval()`` and ``infer(x, f0=..., g=...)``
    """
    pitch_shift = 0  # pitch change applied during conversion
    samples, sample_rate = librosa.load(in_wav, sr=32000, mono=True)
    waveform = torch.tensor(samples).unsqueeze(0)
    speaker = torch.LongTensor([0]).to(self.device).unsqueeze(0)

    units, raw_pitch = self.get_unit_pitch(waveform, sample_rate, pitch_shift)
    f0 = torch.FloatTensor(clean_pitch(raw_pitch)).unsqueeze(0).to(self.device)
    content = torch.FloatTensor(units)

    with torch.no_grad():
        model.eval()
        content = content.unsqueeze(0).to(self.device)
        began = time.time()
        # Each unit frame is repeated twice before being fed to the model.
        content = torch.repeat_interleave(content, repeats=2, dim=1).transpose(1, 2)
        converted = model.infer(content, f0=f0, g=speaker)[0, 0].data.float()
        elapsed = time.time() - began
        print("vits use time:{}".format(elapsed))

    # Write the result to disk.
    soundfile.write(dst_wav, converted.cpu().numpy(), sample_rate, format='wav')
####### Public entry point: train, then run inference
def process_train_and_infer(self, train_media, in_path, dst_path, dst_model_path=None, params={}):
"""
:param train_media: data used for training
:param in_path: vocal recording to be converted
:param dst_path: path of the converted output file
:param dst_model_path: where to cache the trained model; nothing is saved when None
:param params: optional overrides read with .get(): max_step (default 200), c_mel (45), c_fm (1.0), c_gen (1.0)
:return: 0 on success; on the added (+) side of this change, 1 when the 32 kHz
transcode of train_media is missing, 2 when the denoised file is missing,
3 when the transcode of in_path is missing
"""
# NOTE(review): `params={}` is a mutable default; it is only read here, never mutated.
# NOTE(review): the shell commands below are built with str.format and run via
# os.system, so paths containing spaces or shell metacharacters are not quoted.
# Transcode train_media to 32 kHz mono
- tmp_wav = train_media + "_321.wav"
- cmd = "ffmpeg -i {} -ar 32000 -ac 1 -y {}".format(train_media, tmp_wav)
+ tmp_32_wav = train_media + "_321.wav"
+ cmd = "ffmpeg -i {} -ar 32000 -ac 1 -y {}".format(train_media, tmp_32_wav)
os.system(cmd)
# Success is judged only by whether the output file exists afterwards.
- if not os.path.exists(tmp_wav):
+ if not os.path.exists(tmp_32_wav):
return 1
+
+ # Denoise
+ tmp_wav = train_media + "_de321.wav"
+ cmd = "{} {} {}".format(gs_denoise_exe, tmp_32_wav, tmp_wav)
+ os.system(cmd)
+ if not os.path.exists(tmp_wav):
+ os.unlink(tmp_32_wav)
+ return 2
+
in_wav_tmp = in_path + "_321.wav"
cmd = "ffmpeg -i {} -ar 32000 -ac 1 -y {}".format(in_path, in_wav_tmp)
os.system(cmd)
if not os.path.exists(in_wav_tmp):
+ os.unlink(tmp_32_wav)
os.unlink(tmp_wav)
- return 2
+ return 3
# The loss coefficients are written into the module-level gs_train_config,
# so they persist after this call returns.
global gs_train_config
max_step = params.get('max_step', 200)
gs_train_config['c_mel'] = params.get("c_mel", 45)
gs_train_config['c_fm'] = params.get("c_fm", 1.0)
gs_train_config['c_gen'] = params.get("c_gen", 1.0)
print("params:{}".format(params))
st = time.time()
# max_step is passed as train()'s epoch_num argument.
model = self.train(tmp_wav, max_step)
print("train sp={}".format(time.time() - st))
st = time.time()
self.infer(in_wav_tmp, dst_path, model)
print("infer sp={}".format(time.time() - st))
if dst_model_path is not None:
st = time.time()
torch.save(model.state_dict(), dst_model_path)
print("save model sp={}".format(time.time() - st))
# NOTE(review): the temporary files are only removed on this success path; an
# exception raised by train() or infer() would leave them on disk.
+ os.unlink(tmp_32_wav)
os.unlink(tmp_wav)
os.unlink(in_wav_tmp)
return 0
# Run inference with a previously saved model
def process_infer(self, model_path, in_path, dst_path):
    """Convert ``in_path`` with a previously saved generator.

    A generator is built on CUDA and loaded with the state dict stored at
    ``model_path``. The input is transcoded to a temporary 32 kHz mono WAV
    next to ``in_path``, converted by ``self.infer``, and the temporary file
    is then removed.

    :param model_path: path of a generator ``state_dict`` file
    :param in_path: path of the vocal recording to convert
    :param dst_path: path the converted audio is written to
    :return: 2 when the transcoded input file is missing; otherwise None
    """
    net_g = SynthesizerTrn(
        gs_data_config["filter_length"] // 2 + 1,
        gs_train_config["segment_size"] // gs_data_config["hop_length"],
        **gs_model_config,
        no_flow=False,
        use_v3=False).cuda()
    model_dict = torch.load(model_path, map_location='cpu')
    net_g.load_state_dict(model_dict)

    in_wav_tmp = in_path + "_321.wav"
    # NOTE(review): the paths are interpolated into a shell command unquoted;
    # paths with spaces or shell metacharacters will not work as intended.
    cmd = "ffmpeg -i {} -ar 32000 -ac 1 -y {}".format(in_path, in_wav_tmp)
    os.system(cmd)
    # Success is judged only by whether the output file exists afterwards.
    if not os.path.exists(in_wav_tmp):
        return 2

    try:
        self.infer(in_wav_tmp, dst_path, net_g)
    finally:
        # Remove the temporary transcode even when inference raises, so
        # repeated calls do not leave "_321.wav" files behind.
        os.unlink(in_wav_tmp)
    return None
- def get_f0(self, vocal_path):
- get_f0()
-
if __name__ == '__main__':
    # Manual smoke run: fine-tune on one speaker clip, convert a test vocal,
    # and report the wall-clock time together with the return code.
    g_path = "data/online_models/models/base_model/sunyanzi_base_2000.pth"
    d_path = "data/online_models/models/base_model/sunyanzi_base_d_2000.pth"

    pp = "data/train_users/qiankun_v1/vocals/speaker0/qiankun.wav"
    in_p = "data/test/vocal_32.wav"
    dst_p = "data/test/vocal_32_out.wav"
    dst_m_p = "data/test/mm.pth"

    svsot = SoVitsSVCOnlineTrain(g_path, d_path)

    start_time = time.time()
    ret = svsot.process_train_and_infer(
        train_media=pp,
        in_path=in_p,
        dst_path=dst_p,
        dst_model_path=dst_m_p,
    )
    print("process = {} ret={}".format(time.time() - start_time, ret))
File Metadata
Details
Attached
Mime Type
text/x-diff
Expires
Sun, Jan 12, 08:32 (1 d, 15 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
1347184
Default Alt Text
(91 KB)
Attached To
R350 av_svc
Event Timeline
Log In to Comment