diff --git a/AutoCoverTool/online/inference_worker.py b/AutoCoverTool/online/inference_worker.py index 22d7a3c..50d502b 100644 --- a/AutoCoverTool/online/inference_worker.py +++ b/AutoCoverTool/online/inference_worker.py @@ -1,240 +1,240 @@ """ 离线worker 数据库字段要求: // 其中state的状态 // 0:默认,1:被取走,<0异常情况,2完成 // 超时到一定程度也会被重新放回来 数据库格式: id,song_id,url,state,svc_url,create_time,update_time,gender 启动时的环境要求: export PATH=$PATH:/data/gpu_env_common/env/bin/ffmpeg/bin export PYTHONPATH=$PWD:$PWD/ref/music_remover/demucs:$PWD/ref/so_vits_svc:$PWD/ref/split_dirty_frame """ import os import shutil import logging import multiprocessing as mp from online.inference_one import * from online.common import * gs_actw_err_code_download_err = 10001 gs_actw_err_code_trans_err = 10002 gs_actw_err_code_upload_err = 10003 gs_state_default = 0 gs_state_use = 1 gs_state_finish = 2 GS_REGION = "ap-singapore" GS_BUCKET_NAME = "starmaker-sg-1256122840" # GS_COSCMD = "/bin/coscmd" -GS_COSCMD = "coscmd" +GS_COSCMD = "/data/gpu_env_common/env/anaconda3/envs/music_remover_env/bin/coscmd" GS_RES_DIR = "/data/gpu_env_common/res" GS_CONFIG_PATH = os.path.join(GS_RES_DIR, ".online_cos.conf") def exec_cmd(cmd): ret = os.system(cmd) if ret != 0: return False return True def exec_cmd_and_result(cmd): r = os.popen(cmd) text = r.read() r.close() return text def upload_file2cos(key, file_path, region=GS_REGION, bucket_name=GS_BUCKET_NAME): """ 将文件上传到cos :param key: 桶上的具体地址 :param file_path: 本地文件地址 :param region: 区域 :param bucket_name: 桶地址 :return: """ cmd = "{} -c {} -r {} -b {} upload {} {}".format(GS_COSCMD, GS_CONFIG_PATH, region, bucket_name, file_path, key) print(cmd) if exec_cmd(cmd): cmd = "{} -c {} -r {} -b {} info {}".format(GS_COSCMD, GS_CONFIG_PATH, region, bucket_name, key) \ + "| grep Content-Length |awk \'{print $2}\'" res_str = exec_cmd_and_result(cmd) logging.info("{},res={}".format(key, res_str)) size = float(res_str) if size > 0: return True return False return False def post_process_err_callback(msg): print("ERROR|post_process|task_error_callback:", msg) def effect(queue, finish_queue): """ 1. 添加音效 2. 混音 3. 上传到服务端 :return: """ inst = SongCoverInference() while True: logging.info("effect start get...") data = queue.get() song_id, work_dir, svc_file, gender = data logging.info("effect:{},{},{},{}".format(song_id, work_dir, svc_file, gender)) err, effect_file = inst.effect(song_id, work_dir, svc_file) msg = [song_id, err, svc_file, effect_file, gender] logging.info("effect,finish:cid={},state={},svc_file={},effect_file={},gender={}". \ format(song_id, err, svc_file, effect_file, gender)) finish_queue.put(msg) class AutoCoverToolWorker: def __init__(self): self.base_dir = "/tmp" self.work_dir = "" self.inst = SongCoverInference() def update_state(self, song_id, state): sql = "update svc_queue_table set state={},update_time={} where song_id = {}". \ format(state, int(time.time()), song_id) banned_user_map['db'] = "av_db" update_db(sql, banned_user_map) def get_one_data(self): sql = "select song_id, url from svc_queue_table where state = 0 and song_src=1 order by create_time desc limit 1" banned_user_map["db"] = "av_db" data = get_data_by_mysql(sql, banned_user_map) if len(data) == 0: return None, None song_id, song_url = data[0] if song_id != "": self.update_state(song_id, gs_state_use) return str(song_id), song_url def pre_process(self, work_dir, song_url): """ 创建文件夹,下载数据 :return: """ ext = str(song_url).split(".")[-1] dst_file = "{}/src_origin.{}".format(work_dir, ext) cmd = "wget {} -O {}".format(song_url, dst_file) print(cmd) os.system(cmd) if not os.path.exists(dst_file): return gs_actw_err_code_download_err dst_mp3_file = "{}/src.mp3".format(work_dir) cmd = "ffmpeg -i {} -ar 44100 -ac 2 -y {} ".format(dst_file, dst_mp3_file) os.system(cmd) if not os.path.exists(dst_mp3_file): return gs_actw_err_code_trans_err return gs_err_code_success def post_process(self, msg): song_id, err, svc_file, effect_file, gender = msg work_dir = os.path.join(self.base_dir, str(song_id)) if err != gs_err_code_success: self.update_state(song_id, -err) return # 替换和混音 err, mix_path_mp3 = self.inst.mix(song_id, work_dir, svc_file, effect_file) logging.info( "post_process:song_id={},work_dir={},svc_file={},gender={}".format(song_id, work_dir, svc_file, gender)) svc_url = None state = gs_state_finish if err != gs_err_code_success: state = -err else: # 上传到cos mix_name = os.path.basename(mix_path_mp3) key = "av_res/svc_res/{}".format(mix_name) if not upload_file2cos(key, mix_path_mp3): state = -err else: state = gs_state_finish svc_url = key logging.info("upload_file2cos:song_id={},key={},mix_path_mp3={}".format(song_id, key, mix_path_mp3)) # 更新数据库 if state != gs_state_finish: self.update_state(song_id, state) return sql = "update svc_queue_table set state={},update_time={},svc_url=\"{}\",gender={} where song_id = {}". \ format(gs_state_finish, int(time.time()), svc_url, gender, song_id) logging.info("post_process:song_id={},sql={}".format(song_id, sql)) banned_user_map['db'] = "av_db" update_db(sql, banned_user_map) def process(self): logging.info("start_process....") worker_num = 4 worker_queue = mp.Manager().Queue(maxsize=int(worker_num * 1.5)) finish_queue = mp.Manager().Queue(maxsize=int(worker_num * 1.5)) pool = mp.Pool(processes=worker_num) for i in range(worker_num): pool.apply_async(effect, args=(worker_queue, finish_queue), error_callback=post_process_err_callback) while True: # 将堆积的内容处理一遍 while finish_queue.qsize() > 0: msg = finish_queue.get(timeout=1) self.post_process(msg) song_id, err, svc_file, effect_file, gender = msg work_dir = os.path.join(self.base_dir, str(song_id)) logging.info("clear = song_id={},work_dir={}".format(song_id, work_dir)) shutil.rmtree(work_dir) song_id, song_url = self.get_one_data() logging.info("\n\nget_one_data = {},{}".format(song_id, song_url)) if song_id is None: time.sleep(5) continue # 创建空间 work_dir = os.path.join(self.base_dir, str(song_id)) if os.path.exists(work_dir): shutil.rmtree(work_dir) os.makedirs(work_dir) logging.info("song_id={},work_dir={},finish".format(song_id, work_dir)) # 预处理 err = self.pre_process(work_dir, song_url) if err != gs_err_code_success: self.update_state(song_id, -err) shutil.rmtree(work_dir) continue logging.info("song_id={},work_dir={},pre_process".format(song_id, work_dir)) # 获取svc数据 err, svc_file = self.inst.generate_svc_file(song_id, work_dir) if err != gs_err_code_success: self.update_state(song_id, -err) shutil.rmtree(work_dir) continue logging.info("song_id={},work_dir={},generate_svc_file".format(song_id, work_dir)) # 做音效处理的异步代码 gender = self.inst.get_gender(svc_file) worker_queue.put([song_id, work_dir, svc_file, gender]) logging.info("song_id={},work_dir={},svc_file={},gender={}".format(song_id, work_dir, svc_file, gender)) pool.close() pool.join() if __name__ == '__main__': actw = AutoCoverToolWorker() actw.process()