diff --git a/AutoCoverTool/online/inference_worker.py b/AutoCoverTool/online/inference_worker.py
index 22d7a3c..50d502b 100644
--- a/AutoCoverTool/online/inference_worker.py
+++ b/AutoCoverTool/online/inference_worker.py
@@ -1,240 +1,240 @@
"""
离线worker
数据库字段要求:
// 其中state的状态
// 0:默认,1:被取走,<0异常情况,2完成
// 超时到一定程度也会被重新放回来
数据库格式:
id,song_id,url,state,svc_url,create_time,update_time,gender
启动时的环境要求:
export PATH=$PATH:/data/gpu_env_common/env/bin/ffmpeg/bin
export PYTHONPATH=$PWD:$PWD/ref/music_remover/demucs:$PWD/ref/so_vits_svc:$PWD/ref/split_dirty_frame
"""
import os
import shutil
import logging
import multiprocessing as mp
from online.inference_one import *
from online.common import *
gs_actw_err_code_download_err = 10001
gs_actw_err_code_trans_err = 10002
gs_actw_err_code_upload_err = 10003
gs_state_default = 0
gs_state_use = 1
gs_state_finish = 2
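# A failed task is written back with state = -<error code> (see post_process and
# process below), so any negative state encodes the failure reason.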
GS_REGION = "ap-singapore"
GS_BUCKET_NAME = "starmaker-sg-1256122840"
# GS_COSCMD = "/bin/coscmd"
-GS_COSCMD = "coscmd"
+GS_COSCMD = "/data/gpu_env_common/env/anaconda3/envs/music_remover_env/bin/coscmd"
GS_RES_DIR = "/data/gpu_env_common/res"
GS_CONFIG_PATH = os.path.join(GS_RES_DIR, ".online_cos.conf")
def exec_cmd(cmd):
    ret = os.system(cmd)
    if ret != 0:
        return False
    return True
def exec_cmd_and_result(cmd):
    r = os.popen(cmd)
    text = r.read()
    r.close()
    return text
def upload_file2cos(key, file_path, region=GS_REGION, bucket_name=GS_BUCKET_NAME):
"""
将文件上传到cos
:param key: 桶上的具体地址
:param file_path: 本地文件地址
:param region: 区域
:param bucket_name: 桶地址
:return:
"""
    cmd = "{} -c {} -r {} -b {} upload {} {}".format(GS_COSCMD, GS_CONFIG_PATH, region, bucket_name, file_path, key)
    print(cmd)
    if exec_cmd(cmd):
        cmd = "{} -c {} -r {} -b {} info {}".format(GS_COSCMD, GS_CONFIG_PATH, region, bucket_name, key) \
              + "| grep Content-Length |awk \'{print $2}\'"
        res_str = exec_cmd_and_result(cmd)
        logging.info("{},res={}".format(key, res_str))
        size = float(res_str)
        if size > 0:
            return True
        return False
    return False
def post_process_err_callback(msg):
    print("ERROR|post_process|task_error_callback:", msg)
def effect(queue, finish_queue):
"""
1. 添加音效
2. 混音
3. 上传到服务端
:return:
"""
    inst = SongCoverInference()
    while True:
        logging.info("effect start get...")
        data = queue.get()
        song_id, work_dir, svc_file, gender = data
        logging.info("effect:{},{},{},{}".format(song_id, work_dir, svc_file, gender))
        err, effect_file = inst.effect(song_id, work_dir, svc_file)
        msg = [song_id, err, svc_file, effect_file, gender]
        logging.info("effect,finish:cid={},state={},svc_file={},effect_file={},gender={}". \
                     format(song_id, err, svc_file, effect_file, gender))
        finish_queue.put(msg)
class AutoCoverToolWorker:
    def __init__(self):
        self.base_dir = "/tmp"
        self.work_dir = ""
        self.inst = SongCoverInference()
    def update_state(self, song_id, state):
        sql = "update svc_queue_table set state={},update_time={} where song_id = {}". \
            format(state, int(time.time()), song_id)
        banned_user_map['db'] = "av_db"
        update_db(sql, banned_user_map)
    def get_one_data(self):
        sql = "select song_id, url from svc_queue_table where state = 0 and song_src=1 order by create_time desc limit 1"
        banned_user_map["db"] = "av_db"
        data = get_data_by_mysql(sql, banned_user_map)
        if len(data) == 0:
            return None, None
        song_id, song_url = data[0]
        if song_id != "":
            self.update_state(song_id, gs_state_use)
        return str(song_id), song_url
    def pre_process(self, work_dir, song_url):
        """
        Download the source audio into the working directory and transcode it to 44.1 kHz stereo mp3.
        :return: gs_err_code_success, or gs_actw_err_code_download_err / gs_actw_err_code_trans_err on failure
        """
        ext = str(song_url).split(".")[-1]
        dst_file = "{}/src_origin.{}".format(work_dir, ext)
        cmd = "wget {} -O {}".format(song_url, dst_file)
        print(cmd)
        os.system(cmd)
        if not os.path.exists(dst_file):
            return gs_actw_err_code_download_err
        dst_mp3_file = "{}/src.mp3".format(work_dir)
        cmd = "ffmpeg -i {} -ar 44100 -ac 2 -y {} ".format(dst_file, dst_mp3_file)
        os.system(cmd)
        if not os.path.exists(dst_mp3_file):
            return gs_actw_err_code_trans_err
        return gs_err_code_success
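    # After a successful pre_process the work dir holds src_origin.<ext> (the raw download)
    # and src.mp3 (44.1 kHz stereo); the same work_dir is later passed to generate_svc_file.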
    def post_process(self, msg):
        song_id, err, svc_file, effect_file, gender = msg
        work_dir = os.path.join(self.base_dir, str(song_id))
        if err != gs_err_code_success:
            self.update_state(song_id, -err)
            return
        # swap in the converted vocal and mix
        err, mix_path_mp3 = self.inst.mix(song_id, work_dir, svc_file, effect_file)
        logging.info(
            "post_process:song_id={},work_dir={},svc_file={},gender={}".format(song_id, work_dir, svc_file, gender))
        svc_url = None
        state = gs_state_finish
        if err != gs_err_code_success:
            state = -err
        else:
            # upload the mix to COS
            mix_name = os.path.basename(mix_path_mp3)
            key = "av_res/svc_res/{}".format(mix_name)
            if not upload_file2cos(key, mix_path_mp3):
                state = -gs_actw_err_code_upload_err  # upload failed: record the upload error code
            else:
                state = gs_state_finish
                svc_url = key
                logging.info("upload_file2cos:song_id={},key={},mix_path_mp3={}".format(song_id, key, mix_path_mp3))
        # update the database
        if state != gs_state_finish:
            self.update_state(song_id, state)
            return
        sql = "update svc_queue_table set state={},update_time={},svc_url=\"{}\",gender={} where song_id = {}". \
            format(gs_state_finish, int(time.time()), svc_url, gender, song_id)
        logging.info("post_process:song_id={},sql={}".format(song_id, sql))
        banned_user_map['db'] = "av_db"
        update_db(sql, banned_user_map)
    def process(self):
        logging.info("start_process....")
        worker_num = 4
        worker_queue = mp.Manager().Queue(maxsize=int(worker_num * 1.5))
        finish_queue = mp.Manager().Queue(maxsize=int(worker_num * 1.5))
        pool = mp.Pool(processes=worker_num)
        for i in range(worker_num):
            pool.apply_async(effect,
                             args=(worker_queue, finish_queue),
                             error_callback=post_process_err_callback)
        while True:
            # post-process whatever the effect workers have finished so far
            while finish_queue.qsize() > 0:
                msg = finish_queue.get(timeout=1)
                self.post_process(msg)
                song_id, err, svc_file, effect_file, gender = msg
                work_dir = os.path.join(self.base_dir, str(song_id))
                logging.info("clear = song_id={},work_dir={}".format(song_id, work_dir))
                shutil.rmtree(work_dir)
            song_id, song_url = self.get_one_data()
            logging.info("\n\nget_one_data = {},{}".format(song_id, song_url))
            if song_id is None:
                time.sleep(5)
                continue
            # create the working directory
            work_dir = os.path.join(self.base_dir, str(song_id))
            if os.path.exists(work_dir):
                shutil.rmtree(work_dir)
            os.makedirs(work_dir)
            logging.info("song_id={},work_dir={},finish".format(song_id, work_dir))
            # pre-process: download and transcode the source audio
            err = self.pre_process(work_dir, song_url)
            if err != gs_err_code_success:
                self.update_state(song_id, -err)
                shutil.rmtree(work_dir)
                continue
            logging.info("song_id={},work_dir={},pre_process".format(song_id, work_dir))
            # generate the SVC (singing voice conversion) file
            err, svc_file = self.inst.generate_svc_file(song_id, work_dir)
            if err != gs_err_code_success:
                self.update_state(song_id, -err)
                shutil.rmtree(work_dir)
                continue
            logging.info("song_id={},work_dir={},generate_svc_file".format(song_id, work_dir))
            # hand the effect/mix/upload stage off to the asynchronous workers
            gender = self.inst.get_gender(svc_file)
            worker_queue.put([song_id, work_dir, svc_file, gender])
            logging.info("song_id={},work_dir={},svc_file={},gender={}".format(song_id, work_dir, svc_file, gender))
        pool.close()
        pool.join()
if __name__ == '__main__':
    actw = AutoCoverToolWorker()
    actw.process()