diff --git a/AutoCoverTool/online/train_user_get_data.py b/AutoCoverTool/online/train_user_get_data.py index 77278d2..a18101f 100644 --- a/AutoCoverTool/online/train_user_get_data.py +++ b/AutoCoverTool/online/train_user_get_data.py @@ -1,18 +1,40 @@ from online.train_users_model import GenerateUserData """ 成对的数据,其中本代码用于从服务器获取数据,train_users_model用于在北京训练模型 操作分为三部分: 1. python3 train_user_get_data.py | 生成数据 2. python3 train_users_model.py 3. python3 train_users_download_model.py | 下载数据 """ if __name__ == '__main__': gud_inst = GenerateUserData() # gud_inst.get_user_ids_from_db() arr = [ - ["1741807722", "av_res/so_vits_models/3.0_v1/1741807722_1.pth", 2] + ["1774037545", "av_res/so_vits_models/3.0_v1/1774037545_2.pth", 2], + ["3634613822", "av_res/so_vits_models/3.0_v1/3634613822_2.pth", 2], + ["3634620621", "av_res/so_vits_models/3.0_v1/3634620621_1.pth", 1], + ["3634762753", "av_res/so_vits_models/3.0_v1/3634762753_1.pth", 1], + ["3634838529", "av_res/so_vits_models/3.0_v1/3634838529_2.pth", 2], + ["3635098559", "av_res/so_vits_models/3.0_v1/3635098559_1.pth", 1], + ["3635291895", "av_res/so_vits_models/3.0_v1/3635291895_1.pth", 1], + ["3635456421", "av_res/so_vits_models/3.0_v1/3635456421_2.pth", 2], + ["3636200631", "av_res/so_vits_models/3.0_v1/3636200631_2.pth", 2], + ["3636252146", "av_res/so_vits_models/3.0_v1/3636252146_2.pth", 2], + ["3636504073", "av_res/so_vits_models/3.0_v1/3636504073_2.pth", 2], + ["3637044390", "av_res/so_vits_models/3.0_v1/3637044390_2.pth", 2], + ["3637718089", "av_res/so_vits_models/3.0_v1/3637718089_1.pth", 1], + ["3638116992", "av_res/so_vits_models/3.0_v1/3638116992_1.pth", 1], + ["3638859695", "av_res/so_vits_models/3.0_v1/3638859695_1.pth", 1], + ["3639430202", "av_res/so_vits_models/3.0_v1/3639430202_2.pth", 2], + ["3639532516", "av_res/so_vits_models/3.0_v1/3639532516_2.pth", -1], + ["3639734898", "av_res/so_vits_models/3.0_v1/3639734898_1.pth", 1], + ["3639781478", "av_res/so_vits_models/3.0_v1/3639781478_1.pth", 1], + ["3639790461", "av_res/so_vits_models/3.0_v1/3639790461_1.pth", 1], + ["3639851605", "av_res/so_vits_models/3.0_v1/3639851605_2.pth", -1], + ["3640773056", "av_res/so_vits_models/3.0_v1/3640773056_1.pth", 1], + ["3640842366", "av_res/so_vits_models/3.0_v1/3640842366_2.pth", -1], ] gud_inst.download_model_and_update_gender(arr) diff --git a/AutoCoverTool/online/train_users_model.py b/AutoCoverTool/online/train_users_model.py index be90fd2..5a2c9d1 100644 --- a/AutoCoverTool/online/train_users_model.py +++ b/AutoCoverTool/online/train_users_model.py @@ -1,315 +1,316 @@ """ 训练人声模型 输入: 训练使用的音频(干声文件) 输出: 训练出的模型(训练1000轮次)、测试音频的生成结果 """ import os import time import json import glob import shutil import GPUtil from script.train_user_by_one_media import SoVitsSVCOnlineTrain from ref.online.voice_class_online import VoiceClass from online.beanstalk_helper import BeanstalkHelper from online.common import update_db, get_data_by_mysql, upload_file2cos, exec_cmd, get_all_shared_data_by_sql # cos资源目录: av-audit-sync-bj-1256122840 gs_cos_dir = "av_res/so_vits_models/train_users/{user_id}" # 内部结构: record_id.mp4[干声],test_svc_file.m4a gs_record_path = os.path.join(gs_cos_dir, "{record_id}.mp4") gs_svc_path = os.path.join(gs_cos_dir, "example.m4a") gs_model_path = os.path.join("av_res/so_vits_models/3.0_v1/", "{user_id}_{gender}.pth") gs_bean_config = {"addr": "sg-test-common-box-1:11300", "consumer": "auto_user_svc_trainer_v1"} # gs_coscmd = "/data/gpu_env_common/env/anaconda3/envs/auto_song_cover/bin/coscmd" gs_coscmd = "/bin/coscmd" gs_config_path = "/home/normal/.cos.conf" if GPUtil.getGPUs()[0].name == "Tesla T4": gs_coscmd = "/data/gpu_env_common/env/anaconda3/envs/auto_song_cover_t4/bin/coscmd" gs_config_path = "/data/gpu_env_common/bin/cos.conf" gs_model_dir = "/data/gpu_env_common/res/av_svc/models" gs_cache_dir = "/tmp/train_users_model_cache_dir" gs_download_cache_dir = "/tmp/train_users_model_download_cache_dir" gs_region = "ap-beijing" gs_bucket_name = "av-audit-sync-bj-1256122840" gs_tum_err_success = 0 gs_tum_err_already_processed = 1 gs_tum_err_download_origin = 2 gs_tum_err_train_and_inf = 3 gs_tum_err_test_media_transcode = 4 gs_tum_err_gender_err = 5 def get_d(audio_path): cmd = "ffprobe -v quiet -print_format json -show_format -show_streams {}".format(audio_path) data = exec_cmd(cmd) data = json.loads(data) if "format" in data.keys(): if "duration" in data['format']: return float(data["format"]["duration"]) return 0 def download_url(url, dst_path): if url != "": cmd = "wget {} -O {}".format(url, dst_path) os.system(cmd) return os.path.exists(dst_path) def update_gender(user_id, gender): """ 查看数据库,只有当性别是3[未知]再更新 :return: """ sql = "select * from av_db.av_svc_model where user_id=\"{}\" and gender=3".format(user_id) data = get_data_by_mysql(sql) if len(data) == 1: sql = "update av_db.av_svc_model set gender={} where user_id=\"{}\"".format(gender, user_id) update_db(sql) class StatHelper: """ 状态更新类 """ def __init__(self, user_id, record_id): self.user_id = user_id self.record_id = record_id def is_processed(self): """ 库里有数据,并且不是-1,1,2,3意味着可以处理 处理之后更新为3,后面人工听过没问题,设置为1/2 :return: """ sql = "select * from av_db.av_svc_model where user_id = \"{}\" and gender not in (-1, 1, 2, 3)".format( self.user_id) data = get_data_by_mysql(sql) return len(data) > 0 def update_model_url(self, local_model_path, gender, test_svc_m4a): """ 上传到cos上,并更新数据库 :param local_model_path: :param gender: :param test_svc_m4a: :return: """ if not self.is_processed(): # 先上传到cos,然后再更新数据库 key = gs_model_path.format(user_id=self.user_id, gender=gender) key_svc_m4a = gs_svc_path.format(user_id=self.user_id) if upload_file2cos(key, local_model_path, gs_region, gs_bucket_name, gs_coscmd, gs_config_path) \ and upload_file2cos(key_svc_m4a, test_svc_m4a, gs_region, gs_bucket_name, gs_coscmd, gs_config_path): sql = "insert into av_db.av_svc_model (model_version, user_id, model_url, gender) values (\"sovits3.0_v1\", \"{}\", \"{}\", 3)".format( self.user_id, key) update_db(sql) class TrainUserModelOnline: def __init__(self): g_path = os.path.join(gs_model_dir, 'sunyanzi_base_2000.pth') d_path = os.path.join(gs_model_dir, 'sunyanzi_base_d_2000.pth') self.test_svc_src_file = os.path.join(gs_model_dir, "../syz_test.wav") self.ssot_inst = SoVitsSVCOnlineTrain(g_path, d_path) self.cache_dir = None def download_file(self, user_id, record_id): filepath = gs_record_path.format(user_id=user_id, record_id=record_id) local_path = os.path.join(self.cache_dir, "{}.mp4".format(record_id)) cmd = "{} -c {} -r {} -b {} download {} {}".format(gs_coscmd, gs_config_path, gs_region, gs_bucket_name, filepath, local_path) exec_cmd(cmd) return os.path.exists(local_path) def process_one(self, user_id, record_id, gender): """ 1. 检查是否处理过 2. 下载数据 3. 训练并生成测试干声 4. 上传并更新数据库 :param user_id: :param record_id: :param gender: :return: """ sh_inst = StatHelper(user_id, record_id) # 检查文件是否处理过 if sh_inst.is_processed(): print("train_user_model,user_id={},record_id={}, already_processed".format(user_id, record_id)) return gs_tum_err_already_processed # 下载干声文件 if not self.download_file(user_id, record_id): print("train_user_model,user_id={},record_id={}, download_file err".format(user_id, record_id)) return gs_tum_err_download_origin origin_path = os.path.join(self.cache_dir, "{}.mp4".format(record_id)) # 训练并生成干声 dst_path = os.path.join(self.cache_dir, "example.wav") dst_m4a_path = os.path.join(self.cache_dir, "example.m4a") dst_model_path = os.path.join(self.cache_dir, "{}_{}.pth".format(user_id, gender)) ret = self.ssot_inst.process_train_and_infer(origin_path, self.test_svc_src_file, dst_path, dst_model_path, {"max_step": 1000}) if ret != 0: print("train_user_model,user_id={},record_id={}, process_train_and_infer err={}".format(user_id, record_id, ret)) return gs_tum_err_train_and_inf # 对音频做转码 cmd = "ffmpeg -i {} -ar 32000 -ac 1 {}".format(dst_path, dst_m4a_path) print(cmd) os.system(cmd) if not os.path.exists(dst_m4a_path): print("train_user_model,user_id={},record_id={}, transcode".format(user_id, record_id)) return gs_tum_err_test_media_transcode sh_inst.update_model_url(dst_model_path, gender, dst_m4a_path) return gs_tum_err_success def process(self): bean_helper = BeanstalkHelper(gs_bean_config) bean = bean_helper.get_beanstalkd() bean.watch(gs_bean_config["consumer"]) if not os.path.exists(gs_cache_dir): os.makedirs(gs_cache_dir) while True: payload = bean.reserve(5) print(payload) if not payload: print("bean sleep...") continue in_data = json.loads(payload.body) user_id = in_data["user_id"] # 包括user_id/record_id.mp4 record_id = in_data["record_id"] gender = in_data["gender"] self.cache_dir = os.path.join(gs_cache_dir, "{}".format(user_id)) if os.path.exists(self.cache_dir): shutil.rmtree(self.cache_dir) os.makedirs(self.cache_dir) ret = "exp" try: ret = self.process_one(user_id, record_id, gender) except Exception as ex: print("train_user_model,user_id={},record_id={}, ex={}".format(user_id, record_id, ex)) print("train_user_model,user_id={},record_id={}, ret={}".format(user_id, record_id, ret)) payload.delete() if os.path.exists(self.cache_dir): shutil.rmtree(self.cache_dir) continue class GenerateUserData: """ 从数据库筛选出合适的数据 数据要求: 1. 干声 2. 检查出来是纯人声 3. 上传到北京的cos """ def __init__(self): music_voice_pure_model = os.path.join(gs_model_dir, "voice_005_rec_v5.pth") music_voice_no_pure_model = os.path.join(gs_model_dir, "voice_10_v5.pth") gender_pure_model = os.path.join(gs_model_dir, "gender_8k_ratev5_v6_adam.pth") gender_no_pure_model = os.path.join(gs_model_dir, "gender_8k_v6_adam.pth") self.voice_class_inst = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model) def process_one(self, origin_path): # 获取性别信息 gender_map = [2, 1, 3] gender, rate, is_pure = self.voice_class_inst.process(origin_path) # 0女,1男,2未知 | 对外要求: 1男2女3未知 if not is_pure or rate == -1 or gender not in (0, 1, 2): print("train_user_model,gender check, res={}".format([gender, rate, is_pure])) return gs_tum_err_gender_err gender = gender_map[gender] return gender def get_user_ids_from_db(self): """ 从数据库获取数据 :return: """ bean_helper = BeanstalkHelper(gs_bean_config) cur_time = time.time() - 3600 * 4 sql = """ select id, user_id,recording_url from recording where created_on > {cur_time} and grade in ('A++', 'A+') and is_public = 1 and is_deleted = 0 and media_type in (1, 2, 3, 4, 9, 10) """.format(cur_time=cur_time) data = get_all_shared_data_by_sql(sql) if not os.path.exists(gs_download_cache_dir): os.makedirs(gs_download_cache_dir) user_in_bean_cnt = 0 user_cnt = 0 user_ids_in_bean = [] st = time.time() for record_id, user_id, record_url in data: if user_id in user_ids_in_bean: continue origin_path = os.path.join(gs_download_cache_dir, "{}_{}.mp4".format(user_id, record_id)) record_url = str(record_url).replace("master.mp4", "origin_master.mp4") if download_url(record_url, origin_path): gender = self.process_one(origin_path) if gender in (0, 1, 2): # 将文件上传到cos key = gs_record_path.format(user_id=user_id, record_id=record_id) if upload_file2cos(key, origin_path, gs_region, gs_bucket_name, gs_coscmd, gs_config_path): message = json.dumps( {"user_id": str(user_id), "record_id": str(record_id), "gender": int(gender)}) bean_helper.put_payload_to_beanstalk(gs_bean_config["consumer"], message, ttr=2 * 86400) user_ids_in_bean.append(user_id) print("put msg to db={}".format(message)) user_in_bean_cnt += 1 continue if user_cnt % 1000 == 0: print("{}/{}/{} sp={}".format(user_in_bean_cnt, user_cnt, len(data), time.time() - st)) user_cnt += 1 print("finish {}/{}/{} sp={}".format(user_in_bean_cnt, user_cnt, len(data), time.time() - st)) def download_model_and_update_gender(self, arr): st = time.time() for idx, msg in enumerate(arr): user_id, model_url, gender = msg model_name = str(model_url).split("/")[-1] local_path = os.path.join("/data/prod/so_vits_models/3.0/unknown/{}".format(model_name)) cmd = "{} -c {} -r {} -b {} download {} -f {}".format(gs_coscmd, gs_config_path, gs_region, gs_bucket_name, - model_url, local_path) + model_url, local_path) print(cmd) exec_cmd(cmd) + if os.path.exists(local_path): + # 更新数据 + update_gender(user_id, gender) if idx % 10 == 0: print("{}/{} sp={}".format(idx, len(arr), time.time() - st)) - # 更新数据 - update_gender(user_id, gender) print("finish {} sp={}".format(len(arr), time.time() - st)) def put_data2bean(): bean_helper = BeanstalkHelper(gs_bean_config) message = json.dumps({"user_id": "3096224751941488", "record_id": "3096224802670844"}) bean_helper.put_payload_to_beanstalk(gs_bean_config["consumer"], message, ttr=2 * 86400) print("input_finish....") if __name__ == '__main__': tumo_inst = TrainUserModelOnline() tumo_inst.process() # gud_inst = GenerateUserData() # gud_inst.get_user_ids_from_db()