Page Menu
Home
Phabricator
Search
Configure Global Search
Log In
Files
F4880313
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
13 KB
Subscribers
None
View Options
diff --git a/AutoCoverTool/online/train_users_model.py b/AutoCoverTool/online/train_users_model.py
index 5a2c9d1..781ce27 100644
--- a/AutoCoverTool/online/train_users_model.py
+++ b/AutoCoverTool/online/train_users_model.py
@@ -1,316 +1,315 @@
"""
训练人声模型
输入: 训练使用的音频(干声文件)
输出: 训练出的模型(训练1000轮次)、测试音频的生成结果
"""
import glob
import json
import os
import shlex
import shutil
import subprocess
import time
import traceback

import GPUtil

from script.train_user_by_one_media import SoVitsSVCOnlineTrain
from ref.online.voice_class_online import VoiceClass
from online.beanstalk_helper import BeanstalkHelper
from online.common import update_db, get_data_by_mysql, upload_file2cos, exec_cmd, get_all_shared_data_by_sql
# COS resource directory (bucket: av-audit-sync-bj-1256122840)
gs_cos_dir = "av_res/so_vits_models/train_users/{user_id}" # layout: record_id.mp4 [dry vocal], test_svc_file.m4a
# NOTE(review): the test clip is actually stored as example.m4a (see gs_svc_path), not test_svc_file.m4a.
# COS keys of a user's uploaded dry-vocal recording and of the rendered test clip.
gs_record_path = os.path.join(gs_cos_dir, "{record_id}.mp4")
gs_svc_path = os.path.join(gs_cos_dir, "example.m4a")
# COS key of a trained model, one file per (user_id, gender).
gs_model_path = os.path.join("av_res/so_vits_models/3.0_v1/", "{user_id}_{gender}.pth")
# beanstalkd server and tube: GenerateUserData puts jobs on this tube, TrainUserModelOnline watches it.
gs_bean_config = {"addr": "sg-test-common-box-1:11300", "consumer": "auto_user_svc_trainer_v1"}
# gs_coscmd = "/data/gpu_env_common/env/anaconda3/envs/auto_song_cover/bin/coscmd"
# coscmd binary and its config file, used for every COS upload/download in this module.
-gs_coscmd = "/bin/coscmd"
-gs_config_path = "/home/normal/.cos.conf"
+gs_coscmd = "/data/anaconda3/envs/auto_song_cover/bin/coscmd"
+gs_config_path = "/data/gpu_env_common/bin/cos.conf"
# Tesla T4 machines use a different conda env for coscmd.
# NOTE(review): GPUtil.getGPUs()[0] raises IndexError at import time when no GPU is visible.
if GPUtil.getGPUs()[0].name == "Tesla T4":
gs_coscmd = "/data/gpu_env_common/env/anaconda3/envs/auto_song_cover_t4/bin/coscmd"
- gs_config_path = "/data/gpu_env_common/bin/cos.conf"
# Directory holding the base SVC checkpoints and the voice/gender classifier checkpoints.
gs_model_dir = "/data/gpu_env_common/res/av_svc/models"
# Local scratch dirs: per-job training workspace (recreated for every job) and the producer's download cache.
gs_cache_dir = "/tmp/train_users_model_cache_dir"
gs_download_cache_dir = "/tmp/train_users_model_download_cache_dir"
# COS region and bucket for all uploads/downloads.
gs_region = "ap-beijing"
gs_bucket_name = "av-audit-sync-bj-1256122840"
# Result codes: returned by TrainUserModelOnline.process_one, except gs_tum_err_gender_err,
# which GenerateUserData.process_one returns.
gs_tum_err_success = 0
gs_tum_err_already_processed = 1
gs_tum_err_download_origin = 2
gs_tum_err_train_and_inf = 3
gs_tum_err_test_media_transcode = 4
gs_tum_err_gender_err = 5
def get_d(audio_path):
    """
    Return the duration of a media file in seconds, as reported by ffprobe.

    :param audio_path: path of the local media file.
    :return: the container duration as a float, or 0 when it cannot be determined
             (no usable ffprobe output, or no "duration" in the format section).
    """
    # Quote the path so spaces or shell metacharacters cannot split/alter the command.
    cmd = "ffprobe -v quiet -print_format json -show_format -show_streams {}".format(
        shlex.quote(str(audio_path)))
    try:
        data = json.loads(exec_cmd(cmd))
    except (TypeError, ValueError):
        # exec_cmd produced nothing parseable (e.g. empty output); treat as unknown duration.
        return 0
    fmt = data.get("format", {}) if isinstance(data, dict) else {}
    if "duration" in fmt:
        return float(fmt["duration"])
    return 0
def download_url(url, dst_path):
    """
    Download ``url`` to ``dst_path`` with wget.

    :param url: source URL; an empty string skips the download.
    :param dst_path: local destination file.
    :return: True if the download succeeded and dst_path exists; when url is empty,
             whether dst_path already exists.
    """
    if url == "":
        return os.path.exists(dst_path)
    # Pass an argv list so the URL never goes through a shell: an unquoted '&' in a
    # query string would otherwise background wget and truncate the URL.
    try:
        ok = subprocess.run(["wget", str(url), "-O", str(dst_path)]).returncode == 0
    except OSError:
        # wget is missing or not executable.
        ok = False
    if not ok and os.path.exists(dst_path):
        # `wget -O` truncates/creates the output file up front, so a failed transfer
        # leaves an empty or partial file that must not be mistaken for a download.
        os.remove(dst_path)
    return ok and os.path.exists(dst_path)
def update_gender(user_id, gender):
    """
    Set the gender on a user's model rows, but only while the user is still pending (gender 3 = unknown).

    The update runs only when exactly one av_svc_model row for ``user_id`` currently has
    gender 3; it then sets ``gender`` on every row of that user.
    NOTE(review): user_id is interpolated straight into the SQL text and must be trusted.

    :param user_id: user whose av_svc_model rows are updated.
    :param gender: new gender value to store.
    :return: None
    """
    pending_rows = get_data_by_mysql(
        "select * from av_db.av_svc_model where user_id=\"{}\" and gender=3".format(user_id))
    if len(pending_rows) != 1:
        return
    update_db("update av_db.av_svc_model set gender={} where user_id=\"{}\"".format(gender, user_id))
class StatHelper:
    """
    Database/COS status helper for one (user_id, record_id) training job.

    NOTE(review): user_id is interpolated directly into SQL strings; it must come from a trusted source.
    """

    def __init__(self, user_id, record_id):
        self.user_id, self.record_id = user_id, record_id

    def is_processed(self):
        """
        Return True if av_svc_model has a row for this user whose gender is NOT in (-1, 1, 2, 3).

        Original intent (translated): "a row that exists and is not -1/1/2/3 means it can be
        processed; after processing it is set to 3, and after a manual listening check it is
        set to 1/2".
        NOTE(review): update_model_url inserts rows with gender=3, which this query excludes,
        so a user that was already trained is not reported as processed — confirm whether
        `not in` should be `in`.
        """
        query = ("select * from av_db.av_svc_model where user_id = \"{}\" and gender not in (-1, 1, 2, 3)"
                 .format(self.user_id))
        return len(get_data_by_mysql(query)) > 0

    def update_model_url(self, local_model_path, gender, test_svc_m4a):
        """
        Upload the trained model and its test clip to COS, then insert a model row with gender 3.

        Nothing happens if is_processed() is already True. If either upload fails the insert is
        skipped; the failure is not reported to the caller (always returns None).

        :param local_model_path: local path of the trained model file.
        :param gender: gender code embedded in the model's COS key.
        :param test_svc_m4a: local path of the rendered test clip.
        """
        if self.is_processed():
            return
        model_key = gs_model_path.format(user_id=self.user_id, gender=gender)
        clip_key = gs_svc_path.format(user_id=self.user_id)
        # The clip is uploaded only after the model upload succeeded.
        if not upload_file2cos(model_key, local_model_path, gs_region, gs_bucket_name, gs_coscmd,
                               gs_config_path):
            return
        if not upload_file2cos(clip_key, test_svc_m4a, gs_region, gs_bucket_name, gs_coscmd,
                               gs_config_path):
            return
        insert_sql = "insert into av_db.av_svc_model (model_version, user_id, model_url, gender) values (\"sovits3.0_v1\", \"{}\", \"{}\", 3)".format(
            self.user_id, model_key)
        update_db(insert_sql)
class TrainUserModelOnline:
    """
    Beanstalk consumer that trains one so-vits-svc voice model per job.

    A job payload is a JSON object with "user_id", "record_id" and "gender". For each job the
    worker downloads the user's dry-vocal recording from COS, trains a model (max_step=1000),
    renders a test clip, transcodes it to m4a and hands both files to StatHelper for upload.
    """

    def __init__(self):
        # Base generator / discriminator checkpoints handed to the trainer.
        g_path = os.path.join(gs_model_dir, 'sunyanzi_base_2000.pth')
        d_path = os.path.join(gs_model_dir, 'sunyanzi_base_d_2000.pth')
        # Fixed source clip converted with each freshly trained model as an audible check.
        self.test_svc_src_file = os.path.join(gs_model_dir, "../syz_test.wav")
        self.ssot_inst = SoVitsSVCOnlineTrain(g_path, d_path)
        # Per-job working directory; assigned by process() before each job is handled.
        self.cache_dir = None

    @staticmethod
    def _parse_payload(body):
        """
        Decode a job body into (user_id, record_id, gender); user_id and record_id as strings.

        Returns None when the body is not a JSON object containing all three keys, or when an
        id is not a safe single path component. process() builds the per-job directory from
        user_id and rmtree's it, so values such as "/" or ".." must never get through.
        """
        try:
            in_data = json.loads(body)
        except (TypeError, ValueError):
            return None
        if not isinstance(in_data, dict):
            return None
        try:
            # The recording lives at gs_record_path, i.e. <user_id>/<record_id>.mp4.
            user_id = str(in_data["user_id"])
            record_id = str(in_data["record_id"])
            gender = in_data["gender"]
        except KeyError:
            return None
        for part in (user_id, record_id):
            if part in ("", ".", "..") or os.path.basename(part) != part:
                return None
        return user_id, record_id, gender

    def download_file(self, user_id, record_id):
        """
        Download the user's dry-vocal recording from COS into self.cache_dir.

        :return: True if the local file exists after the download attempt.
        """
        filepath = gs_record_path.format(user_id=user_id, record_id=record_id)
        local_path = os.path.join(self.cache_dir, "{}.mp4".format(record_id))
        cmd = "{} -c {} -r {} -b {} download {} {}".format(gs_coscmd, gs_config_path, gs_region, gs_bucket_name,
                                                           shlex.quote(filepath), shlex.quote(local_path))
        exec_cmd(cmd)
        return os.path.exists(local_path)

    def process_one(self, user_id, record_id, gender):
        """
        Train and publish a model for one job.

        1. Skip if StatHelper reports the user as already processed.
        2. Download the dry-vocal recording.
        3. Train and render a test clip from self.test_svc_src_file.
        4. Transcode the clip to m4a, upload model + clip and update the database.

        :param user_id: user id (used in COS keys and local file names).
        :param record_id: id of the dry-vocal recording.
        :param gender: gender code embedded in the model file name.
        :return: one of the gs_tum_err_* codes.
        NOTE(review): step 4 is reported as success even if an upload fails, because
        StatHelper.update_model_url does not return a status.
        """
        sh_inst = StatHelper(user_id, record_id)
        if sh_inst.is_processed():
            print("train_user_model,user_id={},record_id={}, already_processed".format(user_id, record_id))
            return gs_tum_err_already_processed

        if not self.download_file(user_id, record_id):
            print("train_user_model,user_id={},record_id={}, download_file err".format(user_id, record_id))
            return gs_tum_err_download_origin
        origin_path = os.path.join(self.cache_dir, "{}.mp4".format(record_id))

        # Train, then convert the test source clip with the new model.
        dst_path = os.path.join(self.cache_dir, "example.wav")
        dst_m4a_path = os.path.join(self.cache_dir, "example.m4a")
        dst_model_path = os.path.join(self.cache_dir, "{}_{}.pth".format(user_id, gender))
        ret = self.ssot_inst.process_train_and_infer(origin_path, self.test_svc_src_file, dst_path, dst_model_path,
                                                     {"max_step": 1000})
        if ret != 0:
            print("train_user_model,user_id={},record_id={}, process_train_and_infer err={}".format(user_id, record_id,
                                                                                                    ret))
            return gs_tum_err_train_and_inf

        # Transcode the rendered clip (32 kHz mono m4a).
        cmd = "ffmpeg -i {} -ar 32000 -ac 1 {}".format(shlex.quote(dst_path), shlex.quote(dst_m4a_path))
        print(cmd)
        os.system(cmd)
        if not os.path.exists(dst_m4a_path):
            print("train_user_model,user_id={},record_id={}, transcode".format(user_id, record_id))
            return gs_tum_err_test_media_transcode

        sh_inst.update_model_url(dst_model_path, gender, dst_m4a_path)
        return gs_tum_err_success

    def process(self):
        """
        Consume training jobs from the beanstalk tube forever.

        Every reserved job gets one attempt and is deleted afterwards regardless of the
        outcome (there is no retry). Malformed payloads are logged and deleted instead of
        stopping the worker.
        """
        bean_helper = BeanstalkHelper(gs_bean_config)
        bean = bean_helper.get_beanstalkd()
        bean.watch(gs_bean_config["consumer"])
        if not os.path.exists(gs_cache_dir):
            os.makedirs(gs_cache_dir)
        while True:
            payload = bean.reserve(5)
            print(payload)
            if not payload:
                # presumably the 5 s reserve timeout elapsed without a job
                print("bean sleep...")
                continue
            job = self._parse_payload(payload.body)
            if job is None:
                print("train_user_model, invalid payload={}".format(payload.body))
                payload.delete()
                continue
            user_id, record_id, gender = job

            # Fresh, empty working directory for this job.
            self.cache_dir = os.path.join(gs_cache_dir, user_id)
            if os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)
            os.makedirs(self.cache_dir)

            ret = "exp"  # reported when process_one raises
            try:
                ret = self.process_one(user_id, record_id, gender)
            except Exception as ex:
                # Keep the worker alive, but keep the stack trace for diagnosis.
                print("train_user_model,user_id={},record_id={}, ex={}".format(user_id, record_id, ex))
                traceback.print_exc()
            print("train_user_model,user_id={},record_id={}, ret={}".format(user_id, record_id, ret))
            payload.delete()
            if os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)
class GenerateUserData:
    """
    Producer: select suitable recordings from the database and enqueue training jobs.

    Requirements for a recording:
    1. it is a dry vocal;
    2. the voice classifier judges it to be pure voice;
    3. it has been uploaded to the Beijing COS bucket.
    """

    def __init__(self):
        music_voice_pure_model = os.path.join(gs_model_dir, "voice_005_rec_v5.pth")
        music_voice_no_pure_model = os.path.join(gs_model_dir, "voice_10_v5.pth")
        gender_pure_model = os.path.join(gs_model_dir, "gender_8k_ratev5_v6_adam.pth")
        gender_no_pure_model = os.path.join(gs_model_dir, "gender_8k_v6_adam.pth")
        self.voice_class_inst = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model,
                                           gender_no_pure_model)

    def process_one(self, origin_path):
        """
        Classify the voice in a local recording.

        :param origin_path: local path of the recording.
        :return: external gender code (1 male, 2 female, 3 unknown), or gs_tum_err_gender_err
                 when the audio is not pure voice, rate is -1, or the gender is out of range.
        """
        # classifier output: 0 female, 1 male, 2 unknown -> external codes: 1 male, 2 female, 3 unknown
        gender_map = [2, 1, 3]
        gender, rate, is_pure = self.voice_class_inst.process(origin_path)
        if not is_pure or rate == -1 or gender not in (0, 1, 2):
            print("train_user_model,gender check, res={}".format([gender, rate, is_pure]))
            return gs_tum_err_gender_err
        return gender_map[gender]

    def _enqueue_record(self, bean_helper, user_id, record_id, origin_path):
        """Classify one downloaded recording; for a pure male/female voice upload it to COS and
        push a training job. Return True only if the job was enqueued."""
        gender = self.process_one(origin_path)
        # process_one yields 1 (male), 2 (female), 3 (unknown) or an error code; only 1/2 are trained.
        if gender not in (1, 2):
            return False
        key = gs_record_path.format(user_id=user_id, record_id=record_id)
        if not upload_file2cos(key, origin_path, gs_region, gs_bucket_name, gs_coscmd, gs_config_path):
            return False
        message = json.dumps({"user_id": str(user_id), "record_id": str(record_id), "gender": int(gender)})
        bean_helper.put_payload_to_beanstalk(gs_bean_config["consumer"], message, ttr=2 * 86400)
        print("put msg to db={}".format(message))
        return True

    def get_user_ids_from_db(self):
        """
        Scan recordings created in the last 4 hours and enqueue at most one training job per user.

        Candidates are public, non-deleted recordings graded A+/A++ with a media_type in
        (1, 2, 3, 4, 9, 10). Each candidate's "origin_master.mp4" variant is downloaded,
        classified and, if accepted, uploaded and enqueued. The local copy is deleted afterwards.
        """
        bean_helper = BeanstalkHelper(gs_bean_config)
        cur_time = time.time() - 3600 * 4
        sql = """
            select id, user_id,recording_url
            from recording
            where created_on > {cur_time}
            and grade in ('A++', 'A+')
            and is_public = 1 and is_deleted = 0 and media_type in (1, 2, 3, 4, 9, 10)
        """.format(cur_time=cur_time)
        data = get_all_shared_data_by_sql(sql)
        if not os.path.exists(gs_download_cache_dir):
            os.makedirs(gs_download_cache_dir)

        user_in_bean_cnt = 0
        user_cnt = 0
        user_ids_in_bean = set()  # O(1) membership instead of scanning a list for every record
        st = time.time()
        for record_id, user_id, record_url in data:
            if user_id in user_ids_in_bean:
                continue
            origin_path = os.path.join(gs_download_cache_dir, "{}_{}.mp4".format(user_id, record_id))
            record_url = str(record_url).replace("master.mp4", "origin_master.mp4")
            try:
                enqueued = download_url(record_url, origin_path) and \
                    self._enqueue_record(bean_helper, user_id, record_id, origin_path)
            finally:
                # The local copy is only needed until it is classified/uploaded; removing it
                # (and any failed-download leftover) keeps the cache dir from growing unbounded.
                if os.path.exists(origin_path):
                    os.remove(origin_path)
            if enqueued:
                user_ids_in_bean.add(user_id)
                user_in_bean_cnt += 1
                continue
            if user_cnt % 1000 == 0:
                print("{}/{}/{} sp={}".format(user_in_bean_cnt, user_cnt, len(data), time.time() - st))
            user_cnt += 1
        print("finish {}/{}/{} sp={}".format(user_in_bean_cnt, user_cnt, len(data), time.time() - st))

    def download_model_and_update_gender(self, arr):
        """
        Download each listed model from COS into the local "unknown" model dir and, when the
        local file exists afterwards, update that user's gender via update_gender.

        :param arr: sequence of (user_id, model_url, gender) tuples; model_url is the COS
                    path passed to coscmd download.
        NOTE(review): a pre-existing local file also counts as success, even if this download failed.
        """
        st = time.time()
        for idx, (user_id, model_url, gender) in enumerate(arr):
            model_name = str(model_url).split("/")[-1]
            local_path = "/data/prod/so_vits_models/3.0/unknown/{}".format(model_name)
            cmd = "{} -c {} -r {} -b {} download {} -f {}".format(gs_coscmd, gs_config_path, gs_region, gs_bucket_name,
                                                                  model_url, local_path)
            print(cmd)
            exec_cmd(cmd)
            if os.path.exists(local_path):
                update_gender(user_id, gender)
            if idx % 10 == 0:
                print("{}/{} sp={}".format(idx, len(arr), time.time() - st))
        print("finish {} sp={}".format(len(arr), time.time() - st))
def put_data2bean(user_id="3096224751941488", record_id="3096224802670844", gender=3):
    """
    Manually enqueue one training job (debug/ops helper).

    The consumer (TrainUserModelOnline.process) reads user_id, record_id and gender from every
    payload, so all three are sent.

    :param user_id: user whose recording should be trained.
    :param record_id: id of the uploaded dry-vocal recording.
    :param gender: gender code for the job; defaults to 3 (unknown).
    """
    bean_helper = BeanstalkHelper(gs_bean_config)
    message = json.dumps({"user_id": str(user_id), "record_id": str(record_id), "gender": int(gender)})
    bean_helper.put_payload_to_beanstalk(gs_bean_config["consumer"], message, ttr=2 * 86400)
    print("input_finish....")
if __name__ == '__main__':
    # Consumer: train models for jobs arriving on the beanstalk tube (loops forever).
    TrainUserModelOnline().process()
    # Producer, run separately to enqueue new jobs:
    # gud_inst = GenerateUserData()
    # gud_inst.get_user_ids_from_db()
File Metadata
Details
Attached
Mime Type
text/x-diff
Expires
Sun, Jan 12, 08:32 (1 d, 15 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
1347186
Default Alt Text
(13 KB)
Attached To
R350 av_svc
Event Timeline
Log In to Comment