Page MenuHomePhabricator

No OneTemporary

diff --git a/AutoCoverTool/online/train_users_model.py b/AutoCoverTool/online/train_users_model.py
index 5a2c9d1..781ce27 100644
--- a/AutoCoverTool/online/train_users_model.py
+++ b/AutoCoverTool/online/train_users_model.py
@@ -1,316 +1,315 @@
"""
Train a per-user singing-voice (so-vits-svc) model.
Input: the audio used for training (a dry / a cappella vocal file).
Output: the trained model (trained for 1000 steps) and the synthesized result of a test clip.
"""
import os
import time
import json
import glob
import shutil
import GPUtil
from script.train_user_by_one_media import SoVitsSVCOnlineTrain
from ref.online.voice_class_online import VoiceClass
from online.beanstalk_helper import BeanstalkHelper
from online.common import update_db, get_data_by_mysql, upload_file2cos, exec_cmd, get_all_shared_data_by_sql
# COS bucket holding these resources: av-audit-sync-bj-1256122840 (same as gs_bucket_name)
gs_cos_dir = "av_res/so_vits_models/train_users/{user_id}" # per-user layout: {record_id}.mp4 (dry vocal) and example.m4a (test synthesis, see gs_svc_path)
# COS key of the uploaded dry-vocal recording used as training input
gs_record_path = os.path.join(gs_cos_dir, "{record_id}.mp4")
# COS key of the transcoded test synthesis produced after training
gs_svc_path = os.path.join(gs_cos_dir, "example.m4a")
# COS key of the trained model, named {user_id}_{gender}.pth
gs_model_path = os.path.join("av_res/so_vits_models/3.0_v1/", "{user_id}_{gender}.pth")
# beanstalkd address and the tube ("consumer") that training jobs are put on / watched from
gs_bean_config = {"addr": "sg-test-common-box-1:11300", "consumer": "auto_user_svc_trainer_v1"}
# gs_coscmd = "/data/gpu_env_common/env/anaconda3/envs/auto_song_cover/bin/coscmd"
# coscmd binary and its config file used for every COS upload/download in this module;
# on Tesla T4 hosts only the coscmd binary is overridden below.
-gs_coscmd = "/bin/coscmd"
-gs_config_path = "/home/normal/.cos.conf"
+gs_coscmd = "/data/anaconda3/envs/auto_song_cover/bin/coscmd"
+gs_config_path = "/data/gpu_env_common/bin/cos.conf"
# NOTE(review): this runs at import time; if GPUtil.getGPUs() returns an empty list (no visible
# GPU), the [0] index raises IndexError and the module cannot be imported.
if GPUtil.getGPUs()[0].name == "Tesla T4":
    gs_coscmd = "/data/gpu_env_common/env/anaconda3/envs/auto_song_cover_t4/bin/coscmd"
- gs_config_path = "/data/gpu_env_common/bin/cos.conf"
# local directory holding the base so-vits checkpoints and the voice/gender classifier models
gs_model_dir = "/data/gpu_env_common/res/av_svc/models"
# root of the per-job scratch directories (one sub-directory per user_id)
gs_cache_dir = "/tmp/train_users_model_cache_dir"
# scratch directory for recordings downloaded by GenerateUserData
gs_download_cache_dir = "/tmp/train_users_model_download_cache_dir"
# COS region and bucket
gs_region = "ap-beijing"
gs_bucket_name = "av-audit-sync-bj-1256122840"
# Result codes of TrainUserModelOnline.process_one
# (gs_tum_err_gender_err is returned by GenerateUserData.process_one)
gs_tum_err_success = 0
gs_tum_err_already_processed = 1
gs_tum_err_download_origin = 2
gs_tum_err_train_and_inf = 3
gs_tum_err_test_media_transcode = 4
gs_tum_err_gender_err = 5
def get_d(audio_path):
    """
    Return the container duration of a media file in seconds, as reported by ffprobe.

    :param audio_path: local path of the audio/video file
    :return: duration in seconds as a float; 0 when ffprobe produced no usable output or
             reported no numeric duration
    """
    import shlex

    # Quote the path so file names with spaces or shell metacharacters reach ffprobe intact.
    cmd = "ffprobe -v quiet -print_format json -show_format -show_streams {}".format(
        shlex.quote(str(audio_path)))
    # exec_cmd is assumed to return ffprobe's stdout; on failure that is empty or not JSON.
    output = exec_cmd(cmd)
    if not output:
        return 0
    try:
        data = json.loads(output)
    except ValueError:
        return 0
    fmt = data.get("format") if isinstance(data, dict) else None
    duration = fmt.get("duration") if isinstance(fmt, dict) else None
    if duration is None:
        return 0
    try:
        return float(duration)
    except (TypeError, ValueError):
        # e.g. a non-numeric value such as "N/A"
        return 0
def download_url(url, dst_path):
    """
    Download ``url`` to ``dst_path`` with wget.

    :param url: URL to fetch; an empty string skips the download
    :param dst_path: local destination file
    :return: True if ``dst_path`` exists afterwards (for an empty url: if it already existed);
             False when the download failed
    """
    import subprocess

    if url != "":
        # Pass the arguments as a list (no shell): URLs with query strings contain "&",
        # which a shell would treat as "run in background".
        try:
            ret = subprocess.run(["wget", url, "-O", dst_path]).returncode
        except OSError:
            # wget is missing or not executable
            ret = -1
        if ret != 0:
            # wget -O creates/truncates the destination even when the download fails, so
            # remove the empty/partial file instead of reporting it as a success.
            if os.path.exists(dst_path):
                os.remove(dst_path)
            return False
    return os.path.exists(dst_path)
def update_gender(user_id, gender):
    """
    Set the gender of a user's model row, but only while it is still 3 (unknown).

    The update runs only when exactly one av_db.av_svc_model row for ``user_id`` has
    gender 3, and it only touches that row, so rows whose gender was already set
    (e.g. 1/2 after manual review) are left alone.

    :param user_id: user whose model row is updated
    :param gender: new gender value to store
    :return: None
    """
    sql = "select * from av_db.av_svc_model where user_id=\"{}\" and gender=3".format(user_id)
    data = get_data_by_mysql(sql)
    if len(data) == 1:
        # Restrict the UPDATE to the gender=3 row too; otherwise every row of this user,
        # including already-reviewed ones, would be overwritten.
        sql = "update av_db.av_svc_model set gender={} where user_id=\"{}\" and gender=3".format(
            gender, user_id)
        update_db(sql)
class StatHelper:
    """
    Reads and writes the per-user training state kept in av_db.av_svc_model.
    """

    def __init__(self, user_id, record_id):
        self.user_id = user_id
        self.record_id = record_id

    def is_processed(self):
        """
        Return True if av_db.av_svc_model has a row for this user whose gender is NOT one
        of -1, 1, 2, 3.

        Gender lifecycle: after processing, the row is set to 3; after a manual listening
        check confirms it, it is set to 1/2.

        NOTE(review): update_model_url() inserts rows with gender 3, which this filter
        excludes, so a user already trained by this pipeline is not reported as processed
        and can be trained and inserted again. update_gender() then no longer matches,
        because it requires exactly one gender=3 row. Confirm whether the filter should be
        ``gender in (-1, 1, 2, 3)``.

        :return: bool
        """
        sql = "select * from av_db.av_svc_model where user_id = \"{}\" and gender not in (-1, 1, 2, 3)".format(
            self.user_id)
        data = get_data_by_mysql(sql)
        return len(data) > 0

    def update_model_url(self, local_model_path, gender, test_svc_m4a):
        """
        Upload the trained model and its test synthesis to COS, then insert the model row.

        Nothing is uploaded when is_processed() is already True. The row is inserted with
        gender 3 (unknown) regardless of ``gender``; ``gender`` only affects the model's
        COS key.

        :param local_model_path: local path of the trained .pth model
        :param gender: gender used in the model key ({user_id}_{gender}.pth)
        :param test_svc_m4a: local path of the transcoded test synthesis
        :return: True if the row was inserted; False if the user was already processed or
                 an upload failed
        """
        if self.is_processed():
            return False
        # Upload to COS first, and only then update the database.
        key = gs_model_path.format(user_id=self.user_id, gender=gender)
        key_svc_m4a = gs_svc_path.format(user_id=self.user_id)
        if not upload_file2cos(key, local_model_path, gs_region, gs_bucket_name, gs_coscmd,
                               gs_config_path):
            print("train_user_model,user_id={},record_id={}, upload model err".format(
                self.user_id, self.record_id))
            return False
        if not upload_file2cos(key_svc_m4a, test_svc_m4a, gs_region, gs_bucket_name, gs_coscmd,
                               gs_config_path):
            print("train_user_model,user_id={},record_id={}, upload test svc err".format(
                self.user_id, self.record_id))
            return False
        sql = "insert into av_db.av_svc_model (model_version, user_id, model_url, gender) values (\"sovits3.0_v1\", \"{}\", \"{}\", 3)".format(
            self.user_id, key)
        update_db(sql)
        return True
class TrainUserModelOnline:
    """
    Beanstalk consumer that trains one per-user so-vits-svc model per job.

    A job is a JSON payload with "user_id", "record_id" and "gender". For each job:
    1. the user's dry-vocal recording is downloaded from COS;
    2. a model is trained from the base checkpoints;
    3. a test clip is synthesized and transcoded;
    4. both files are uploaded and registered through StatHelper.
    """

    def __init__(self):
        # Base checkpoints handed to SoVitsSVCOnlineTrain (presumably generator "g" and
        # discriminator "d").
        g_path = os.path.join(gs_model_dir, 'sunyanzi_base_2000.pth')
        d_path = os.path.join(gs_model_dir, 'sunyanzi_base_d_2000.pth')
        # Source audio passed as the inference input when producing the test clip.
        self.test_svc_src_file = os.path.join(gs_model_dir, "../syz_test.wav")
        self.ssot_inst = SoVitsSVCOnlineTrain(g_path, d_path)
        # Per-job scratch directory; assigned by process() before each process_one() call.
        self.cache_dir = None

    def download_file(self, user_id, record_id):
        """
        Download the user's uploaded recording from COS into self.cache_dir.

        :param user_id: owner of the recording
        :param record_id: recording id ({record_id}.mp4 under the user's COS folder)
        :return: True if {record_id}.mp4 exists in self.cache_dir afterwards
        """
        filepath = gs_record_path.format(user_id=user_id, record_id=record_id)
        local_path = os.path.join(self.cache_dir, "{}.mp4".format(record_id))
        cmd = "{} -c {} -r {} -b {} download {} {}".format(gs_coscmd, gs_config_path, gs_region, gs_bucket_name,
                                                          filepath, local_path)
        exec_cmd(cmd)
        return os.path.exists(local_path)

    def process_one(self, user_id, record_id, gender):
        """
        Train and publish the model for one user.

        Steps:
        1. skip if the user is already processed;
        2. download the recording;
        3. train and synthesize the test clip;
        4. transcode the clip, upload both files and update the database.

        :param user_id: user whose recording is trained on
        :param record_id: id of the recording to train on
        :param gender: gender value used in the model file name
        :return: one of the gs_tum_err_* result codes
        """
        import subprocess

        sh_inst = StatHelper(user_id, record_id)
        if sh_inst.is_processed():
            print("train_user_model,user_id={},record_id={}, already_processed".format(user_id, record_id))
            return gs_tum_err_already_processed

        if not self.download_file(user_id, record_id):
            print("train_user_model,user_id={},record_id={}, download_file err".format(user_id, record_id))
            return gs_tum_err_download_origin
        origin_path = os.path.join(self.cache_dir, "{}.mp4".format(record_id))

        dst_path = os.path.join(self.cache_dir, "example.wav")
        dst_m4a_path = os.path.join(self.cache_dir, "example.m4a")
        dst_model_path = os.path.join(self.cache_dir, "{}_{}.pth".format(user_id, gender))
        # Train with max_step=1000. Judging by the argument names, this writes the model to
        # dst_model_path and the synthesized test audio to dst_path.
        ret = self.ssot_inst.process_train_and_infer(origin_path, self.test_svc_src_file, dst_path, dst_model_path,
                                                     {"max_step": 1000})
        if ret != 0:
            print("train_user_model,user_id={},record_id={}, process_train_and_infer err={}".format(user_id, record_id,
                                                                                                   ret))
            return gs_tum_err_train_and_inf

        # Transcode the test clip to 32 kHz mono.
        cmd = ["ffmpeg", "-i", dst_path, "-ar", "32000", "-ac", "1", dst_m4a_path]
        print(" ".join(cmd))
        try:
            # List form avoids shell parsing of the user-derived paths. stdin is detached so
            # ffmpeg cannot consume or block on the worker's terminal input.
            subprocess.run(cmd, stdin=subprocess.DEVNULL)
        except OSError as ex:
            print("train_user_model,user_id={},record_id={}, ffmpeg ex={}".format(user_id, record_id, ex))
        if not os.path.exists(dst_m4a_path):
            print("train_user_model,user_id={},record_id={}, transcode".format(user_id, record_id))
            return gs_tum_err_test_media_transcode

        sh_inst.update_model_url(dst_model_path, gender, dst_m4a_path)
        return gs_tum_err_success

    def process(self):
        """
        Consume training jobs from the beanstalk tube forever.

        Every reserved job is deleted once it has been handled, whether it succeeded,
        failed or was malformed. A bad payload is logged and dropped instead of killing
        the worker.
        """
        bean_helper = BeanstalkHelper(gs_bean_config)
        bean = bean_helper.get_beanstalkd()
        bean.watch(gs_bean_config["consumer"])
        if not os.path.exists(gs_cache_dir):
            os.makedirs(gs_cache_dir)
        while True:
            payload = bean.reserve(5)
            print(payload)
            if not payload:
                print("bean sleep...")
                continue
            job = self._parse_payload(payload.body)
            if job is not None:
                self._run_job(*job)
            payload.delete()

    @staticmethod
    def _parse_payload(body):
        """Decode a job body into (user_id, record_id, gender); return None if it is malformed."""
        try:
            in_data = json.loads(body)
            return in_data["user_id"], in_data["record_id"], in_data["gender"]
        except (ValueError, KeyError, TypeError) as ex:
            print("train_user_model, bad payload={}, ex={}".format(body, ex))
            return None

    def _run_job(self, user_id, record_id, gender):
        """Run process_one() in a fresh scratch directory and log the result; exceptions are logged, not raised."""
        import traceback

        self.cache_dir = os.path.join(gs_cache_dir, "{}".format(user_id))
        ret = "exp"
        try:
            if os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)
            os.makedirs(self.cache_dir)
            ret = self.process_one(user_id, record_id, gender)
        except Exception as ex:
            print("train_user_model,user_id={},record_id={}, ex={}".format(user_id, record_id, ex))
            traceback.print_exc()
        print("train_user_model,user_id={},record_id={}, ret={}".format(user_id, record_id, ret))
        # Drop the job's scratch files (recording, model, test clips) whatever the outcome.
        shutil.rmtree(self.cache_dir, ignore_errors=True)
class GenerateUserData:
    """
    Selects suitable recordings from the database and queues them for training.

    A recording qualifies when:
    1. it is a dry (a cappella) vocal;
    2. the voice classifier judges it to be pure vocals;
    3. it has been uploaded to the Beijing COS bucket.
    """

    def __init__(self):
        # Checkpoints for the pure / non-pure vocal detectors and the matching gender models.
        music_voice_pure_model = os.path.join(gs_model_dir, "voice_005_rec_v5.pth")
        music_voice_no_pure_model = os.path.join(gs_model_dir, "voice_10_v5.pth")
        gender_pure_model = os.path.join(gs_model_dir, "gender_8k_ratev5_v6_adam.pth")
        gender_no_pure_model = os.path.join(gs_model_dir, "gender_8k_v6_adam.pth")
        self.voice_class_inst = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model,
                                           gender_no_pure_model)

    def process_one(self, origin_path):
        """
        Classify the gender of one recording.

        :param origin_path: local path of the recording
        :return: 1 (male), 2 (female) or 3 (unknown); gs_tum_err_gender_err when the audio
                 is not pure vocals or the classifier reports failure
        """
        # The classifier uses 0=female, 1=male, 2=unknown; the external convention is
        # 1=male, 2=female, 3=unknown.
        gender_map = [2, 1, 3]
        gender, rate, is_pure = self.voice_class_inst.process(origin_path)
        if not is_pure or rate == -1 or gender not in (0, 1, 2):
            print("train_user_model,gender check, res={}".format([gender, rate, is_pure]))
            return gs_tum_err_gender_err
        return gender_map[gender]

    def get_user_ids_from_db(self):
        """
        Queue training jobs for recent high-grade recordings, at most one job per user.

        Recordings created in the last 4 hours with grade A++/A+ are scanned. For each
        user not yet queued in this run, the recording is downloaded, classified and, if
        classified male or female, uploaded to COS and put on the beanstalk tube. The local
        download is always removed afterwards.
        """
        bean_helper = BeanstalkHelper(gs_bean_config)
        cur_time = time.time() - 3600 * 4
        sql = """
        select id, user_id,recording_url
        from recording
        where created_on > {cur_time}
        and grade in ('A++', 'A+')
        and is_public = 1 and is_deleted = 0 and media_type in (1, 2, 3, 4, 9, 10)
        """.format(cur_time=cur_time)
        data = get_all_shared_data_by_sql(sql)
        if not os.path.exists(gs_download_cache_dir):
            os.makedirs(gs_download_cache_dir)

        user_in_bean_cnt = 0
        user_cnt = 0  # number of records scanned so far
        user_ids_in_bean = set()
        st = time.time()
        for record_id, user_id, record_url in data:
            if user_cnt % 1000 == 0:
                print("{}/{}/{} sp={}".format(user_in_bean_cnt, user_cnt, len(data), time.time() - st))
            user_cnt += 1
            if user_id in user_ids_in_bean:
                continue
            origin_path = os.path.join(gs_download_cache_dir, "{}_{}.mp4".format(user_id, record_id))
            try:
                if self._queue_record(bean_helper, record_id, user_id, record_url, origin_path):
                    user_ids_in_bean.add(user_id)
                    user_in_bean_cnt += 1
            finally:
                # The local copy is not needed after classification/upload; without this the
                # download directory grows with every scanned recording.
                if os.path.exists(origin_path):
                    os.remove(origin_path)
        print("finish {}/{}/{} sp={}".format(user_in_bean_cnt, user_cnt, len(data), time.time() - st))

    def _queue_record(self, bean_helper, record_id, user_id, record_url, origin_path):
        """Download, classify, upload and queue one recording; return True if a job was queued."""
        record_url = str(record_url).replace("master.mp4", "origin_master.mp4")
        if not download_url(record_url, origin_path):
            return False
        gender = self.process_one(origin_path)
        # Only male (1) and female (2) are queued; unknown (3) and error codes are skipped.
        if gender not in (1, 2):
            return False
        key = gs_record_path.format(user_id=user_id, record_id=record_id)
        if not upload_file2cos(key, origin_path, gs_region, gs_bucket_name, gs_coscmd, gs_config_path):
            return False
        message = json.dumps(
            {"user_id": str(user_id), "record_id": str(record_id), "gender": int(gender)})
        bean_helper.put_payload_to_beanstalk(gs_bean_config["consumer"], message, ttr=2 * 86400)
        print("put msg to db={}".format(message))
        return True

    def download_model_and_update_gender(self, arr):
        """
        Download each listed model from COS and, if that succeeded, store its user's gender.

        :param arr: sized sequence of (user_id, model_url, gender) tuples; model_url is the
                    COS path handed to coscmd download
        """
        st = time.time()
        for idx, (user_id, model_url, gender) in enumerate(arr):
            model_name = str(model_url).split("/")[-1]
            local_path = os.path.join("/data/prod/so_vits_models/3.0/unknown", model_name)
            cmd = "{} -c {} -r {} -b {} download {} -f {}".format(gs_coscmd, gs_config_path, gs_region, gs_bucket_name,
                                                                 model_url, local_path)
            print(cmd)
            exec_cmd(cmd)
            if os.path.exists(local_path):
                update_gender(user_id, gender)
            if idx % 10 == 0:
                print("{}/{} sp={}".format(idx, len(arr), time.time() - st))
        print("finish {} sp={}".format(len(arr), time.time() - st))
def put_data2bean(user_id="3096224751941488", record_id="3096224802670844", gender=3):
    """
    Manually queue one training job (debug helper).

    The consumer (TrainUserModelOnline.process) reads "user_id", "record_id" and "gender"
    from every payload, so all three keys are sent.

    :param user_id: user to train
    :param record_id: recording to train on
    :param gender: 1 male, 2 female, 3 unknown (default)
    """
    bean_helper = BeanstalkHelper(gs_bean_config)
    message = json.dumps({"user_id": str(user_id), "record_id": str(record_id), "gender": int(gender)})
    bean_helper.put_payload_to_beanstalk(gs_bean_config["consumer"], message, ttr=2 * 86400)
    print("input_finish....")
if __name__ == '__main__':
    trainer = TrainUserModelOnline()
    # Consumer mode: block forever, training one model per beanstalk job.
    trainer.process()
    # Producer mode (queue new jobs instead of consuming them):
    # GenerateUserData().get_user_ids_from_db()

File Metadata

Mime Type
text/x-diff
Expires
Sun, Jan 12, 08:32 (1 d, 15 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
1347186
Default Alt Text
(13 KB)

Event Timeline