Page MenuHomePhabricator

No OneTemporary

diff --git a/AutoCoverTool/online/inference_worker.py b/AutoCoverTool/online/inference_worker.py
index 22d7a3c..50d502b 100644
--- a/AutoCoverTool/online/inference_worker.py
+++ b/AutoCoverTool/online/inference_worker.py
@@ -1,240 +1,240 @@
"""
离线worker
数据库字段要求:
// 其中state的状态
// 0:默认,1:被取走,<0异常情况,2完成
// 超时到一定程度也会被重新放回来
数据库格式:
id,song_id,url,state,svc_url,create_time,update_time,gender
启动时的环境要求:
export PATH=$PATH:/data/gpu_env_common/env/bin/ffmpeg/bin
export PYTHONPATH=$PWD:$PWD/ref/music_remover/demucs:$PWD/ref/so_vits_svc:$PWD/ref/split_dirty_frame
"""
import os
import shutil
import logging
import multiprocessing as mp
from online.inference_one import *
from online.common import *
gs_actw_err_code_download_err = 10001
gs_actw_err_code_trans_err = 10002
gs_actw_err_code_upload_err = 10003
gs_state_default = 0
gs_state_use = 1
gs_state_finish = 2
GS_REGION = "ap-singapore"
GS_BUCKET_NAME = "starmaker-sg-1256122840"
# GS_COSCMD = "/bin/coscmd"
-GS_COSCMD = "coscmd"
+GS_COSCMD = "/data/gpu_env_common/env/anaconda3/envs/music_remover_env/bin/coscmd"
GS_RES_DIR = "/data/gpu_env_common/res"
GS_CONFIG_PATH = os.path.join(GS_RES_DIR, ".online_cos.conf")
def exec_cmd(cmd):
    """Run *cmd* through the shell and report whether it succeeded.

    :param cmd: command line handed verbatim to ``os.system``.
    :return: True when the reported status is 0, otherwise False.
    """
    status = os.system(cmd)
    return status == 0
def exec_cmd_and_result(cmd):
    """Run *cmd* through the shell and return everything it wrote to stdout.

    The command's exit status is discarded; only the captured text is returned.
    """
    with os.popen(cmd) as pipe:
        output = pipe.read()
    return output
def upload_file2cos(key, file_path, region=GS_REGION, bucket_name=GS_BUCKET_NAME):
    """
    Upload a local file to COS and verify that the remote object is non-empty.

    :param key: object key (path inside the bucket)
    :param file_path: local path of the file to upload
    :param region: COS region
    :param bucket_name: COS bucket name
    :return: True if the upload command succeeded and ``coscmd info`` reports a
             positive Content-Length for *key*; False otherwise (including when
             that size cannot be read).
    """
    base_cmd = "{} -c {} -r {} -b {}".format(GS_COSCMD, GS_CONFIG_PATH, region, bucket_name)
    cmd = "{} upload {} {}".format(base_cmd, file_path, key)
    print(cmd)
    if not exec_cmd(cmd):
        return False
    # Read the object's size back to confirm the upload actually landed.
    cmd = "{} info {}".format(base_cmd, key) + "| grep Content-Length |awk '{print $2}'"
    res_str = exec_cmd_and_result(cmd)
    logging.info("{},res={}".format(key, res_str))
    # If `info` fails or prints no Content-Length line, the captured text is
    # empty/unexpected and float() raises ValueError; previously that escaped
    # into the caller's main loop. Treat it as a failed upload instead.
    try:
        size = float(res_str)
    except ValueError:
        logging.error("upload_file2cos: cannot parse size for key={}: {!r}".format(key, res_str))
        return False
    return size > 0
def post_process_err_callback(msg):
    """Print an error raised by an asynchronous pool task (used as ``error_callback``)."""
    line = " ".join(["ERROR|post_process|task_error_callback:", str(msg)])
    print(line)
def effect(queue, finish_queue):
    """
    Pool worker loop that applies audio effects to the generated SVC vocals.

    Blocks on *queue* for ``[song_id, work_dir, svc_file, gender]`` items, runs
    ``SongCoverInference.effect`` on each, and puts
    ``[song_id, err, svc_file, effect_file, gender]`` on *finish_queue*. Mixing,
    uploading and the database update are done afterwards by the main process
    (``AutoCoverToolWorker.post_process``). This function never returns.

    :param queue: input queue fed by ``AutoCoverToolWorker.process``
    :param finish_queue: output queue drained by the main process
    """
    inst = SongCoverInference()
    while True:
        logging.info("effect start get...")
        data = queue.get()
        song_id, work_dir, svc_file, gender = data
        logging.info("effect:{},{},{},{}".format(song_id, work_dir, svc_file, gender))
        try:
            err, effect_file = inst.effect(song_id, work_dir, svc_file)
        except Exception:
            # An uncaught exception would end the pool task running this loop,
            # leaving one fewer consumer of `queue`. The song would also never
            # be reported back: it would stay in state 1 and its work_dir would
            # leak. Report it as a failure so the main loop records it and
            # cleans up.
            logging.exception("effect,exception:cid={},svc_file={}".format(song_id, svc_file))
            # TODO(review): consider a dedicated error code for effect failures.
            err, effect_file = gs_actw_err_code_trans_err, None
        msg = [song_id, err, svc_file, effect_file, gender]
        logging.info("effect,finish:cid={},state={},svc_file={},effect_file={},gender={}".
                     format(song_id, err, svc_file, effect_file, gender))
        finish_queue.put(msg)
class AutoCoverToolWorker:
    """
    Main-process driver of the cover pipeline.

    Claims pending rows from ``svc_queue_table``, downloads and transcodes the
    source audio, generates the SVC vocals, and hands them to a pool of
    ``effect`` workers. ``post_process`` then mixes each effect result, uploads
    it to COS and writes it back to the database.
    """

    def __init__(self):
        # Per-song scratch directories are created as <base_dir>/<song_id>.
        self.base_dir = "/tmp"
        self.work_dir = ""
        self.inst = SongCoverInference()

    @staticmethod
    def _run(args):
        """
        Run *args* (an argv list) without a shell and return its exit status.

        Using an argv list lets URLs with shell metacharacters (``&``, ``;``,
        spaces, ...) reach the program intact. Returns 127, the shell's
        "command not found" status, if the program cannot be started.
        """
        import subprocess
        try:
            return subprocess.call(args)
        except OSError:
            logging.exception("cannot start command: {}".format(args))
            return 127

    def update_state(self, song_id, state):
        """Set ``state`` and ``update_time`` of *song_id*'s queue row."""
        sql = "update svc_queue_table set state={},update_time={} where song_id = {}". \
            format(state, int(time.time()), song_id)
        banned_user_map['db'] = "av_db"
        update_db(sql, banned_user_map)

    def get_one_data(self):
        """
        Claim the newest pending row (state 0, song_src 1) by moving it to state 1.

        :return: ``(song_id, url)`` with song_id as str, or ``(None, None)`` when
                 there is nothing usable to process.
        """
        sql = "select song_id, url from svc_queue_table where state = 0 and song_src=1 order by create_time desc limit 1"
        banned_user_map["db"] = "av_db"
        data = get_data_by_mysql(sql, banned_user_map)
        if not data:
            return None, None
        song_id, song_url = data[0]
        song_id = str(song_id)
        if song_id == "":
            # An empty id would make the work dir <base_dir> itself, which
            # process() would then rmtree. Refuse such a row.
            logging.error("get_one_data: empty song_id, url={}".format(song_url))
            return None, None
        self.update_state(song_id, gs_state_use)
        return song_id, song_url

    def pre_process(self, work_dir, song_url):
        """
        Download *song_url* into *work_dir* and transcode it to
        ``<work_dir>/src.mp3`` (44100 Hz, 2 channels).

        :return: gs_err_code_success, gs_actw_err_code_download_err or
                 gs_actw_err_code_trans_err
        """
        from urllib.parse import urlparse

        song_url = str(song_url)
        # Take the extension from the URL *path* only, so a query string or a
        # dot-free path does not end up in the local file name.
        ext = os.path.splitext(urlparse(song_url).path)[1]
        dst_file = "{}/src_origin{}".format(work_dir, ext)
        cmd = ["wget", song_url, "-O", dst_file]
        print(" ".join(cmd))
        ret = self._run(cmd)
        # `wget -O` creates/truncates the output file before downloading, so a
        # failed download can still leave an (empty) file behind.
        if ret != 0 or not os.path.exists(dst_file) or os.path.getsize(dst_file) == 0:
            return gs_actw_err_code_download_err
        dst_mp3_file = "{}/src.mp3".format(work_dir)
        ret = self._run(["ffmpeg", "-i", dst_file, "-ar", "44100", "-ac", "2", "-y", dst_mp3_file])
        if ret != 0 or not os.path.exists(dst_mp3_file):
            return gs_actw_err_code_trans_err
        return gs_err_code_success

    def post_process(self, msg):
        """
        Finish one song after its effect step: mix, upload to COS, update the DB.

        :param msg: ``[song_id, err, svc_file, effect_file, gender]`` as put on
                    the finish queue by ``effect``.
        """
        song_id, err, svc_file, effect_file, gender = msg
        work_dir = os.path.join(self.base_dir, str(song_id))
        if err != gs_err_code_success:
            self.update_state(song_id, -err)
            return
        # Replace the vocals and mix.
        err, mix_path_mp3 = self.inst.mix(song_id, work_dir, svc_file, effect_file)
        logging.info(
            "post_process:song_id={},work_dir={},svc_file={},gender={}".format(song_id, work_dir, svc_file, gender))
        if err != gs_err_code_success:
            self.update_state(song_id, -err)
            return
        # Upload the mix to COS.
        mix_name = os.path.basename(mix_path_mp3)
        key = "av_res/svc_res/{}".format(mix_name)
        if not upload_file2cos(key, mix_path_mp3):
            # `err` equals gs_err_code_success here, so the previous `-err`
            # stored -gs_err_code_success (not an error code) for a failed
            # upload. Use the dedicated upload error code instead.
            logging.error("upload_file2cos failed:song_id={},key={},mix_path_mp3={}".format(
                song_id, key, mix_path_mp3))
            self.update_state(song_id, -gs_actw_err_code_upload_err)
            return
        svc_url = key
        logging.info("upload_file2cos:song_id={},key={},mix_path_mp3={}".format(song_id, key, mix_path_mp3))
        # Mark the row finished and record where the result lives.
        sql = "update svc_queue_table set state={},update_time={},svc_url=\"{}\",gender={} where song_id = {}". \
            format(gs_state_finish, int(time.time()), svc_url, gender, song_id)
        logging.info("post_process:song_id={},sql={}".format(song_id, sql))
        banned_user_map['db'] = "av_db"
        update_db(sql, banned_user_map)

    def _drain_finished(self, finish_queue):
        """Post-process every effect result currently waiting, then delete its work dir."""
        while finish_queue.qsize() > 0:
            msg = finish_queue.get(timeout=1)
            song_id = msg[0]
            work_dir = os.path.join(self.base_dir, str(song_id))
            try:
                self.post_process(msg)
            except Exception:
                # Keep the service alive on a single bad song. The row is left
                # in state 1 (the module notes say timed-out rows get re-queued).
                logging.exception("post_process failed:song_id={}".format(song_id))
            finally:
                logging.info("clear = song_id={},work_dir={}".format(song_id, work_dir))
                shutil.rmtree(work_dir, ignore_errors=True)

    def process(self):
        """
        Run the worker forever.

        Starts ``worker_num`` ``effect`` workers. Each iteration of the main
        loop first post-processes finished effect results, then claims one song,
        downloads/transcodes it, generates the SVC vocals and queues them for
        the effect step.
        """
        logging.info("start_process....")
        worker_num = 4
        queue_size = int(worker_num * 1.5)
        manager = mp.Manager()
        worker_queue = manager.Queue(maxsize=queue_size)
        finish_queue = manager.Queue(maxsize=queue_size)
        pool = mp.Pool(processes=worker_num)
        for _ in range(worker_num):
            pool.apply_async(effect,
                             args=(worker_queue, finish_queue),
                             error_callback=post_process_err_callback)
        try:
            while True:
                # Handle whatever has piled up on the finish queue.
                self._drain_finished(finish_queue)

                song_id, song_url = self.get_one_data()
                logging.info("\n\nget_one_data = {},{}".format(song_id, song_url))
                if song_id is None:
                    time.sleep(5)
                    continue

                # Create a fresh scratch directory for this song.
                work_dir = os.path.join(self.base_dir, str(song_id))
                if os.path.exists(work_dir):
                    shutil.rmtree(work_dir)
                os.makedirs(work_dir)
                logging.info("song_id={},work_dir={},finish".format(song_id, work_dir))

                # Download + transcode.
                err = self.pre_process(work_dir, song_url)
                if err != gs_err_code_success:
                    self.update_state(song_id, -err)
                    shutil.rmtree(work_dir)
                    continue
                logging.info("song_id={},work_dir={},pre_process".format(song_id, work_dir))

                # Generate the SVC vocals.
                err, svc_file = self.inst.generate_svc_file(song_id, work_dir)
                if err != gs_err_code_success:
                    self.update_state(song_id, -err)
                    shutil.rmtree(work_dir)
                    continue
                logging.info("song_id={},work_dir={},generate_svc_file".format(song_id, work_dir))

                # Hand off to the asynchronous effect workers.
                gender = self.inst.get_gender(svc_file)
                worker_queue.put([song_id, work_dir, svc_file, gender])
                logging.info("song_id={},work_dir={},svc_file={},gender={}".format(song_id, work_dir, svc_file, gender))
        finally:
            # The loop only exits via an exception. The effect() workers never
            # return, so close()+join() would block forever; stop them instead.
            pool.terminate()
            pool.join()
if __name__ == '__main__':
    # Entry point: build the worker and run its polling loop forever.
    AutoCoverToolWorker().process()

File Metadata

Mime Type
text/x-diff
Expires
Sun, Jan 12, 08:31 (1 d, 10 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
1347173
Default Alt Text
(8 KB)

Event Timeline