diff --git a/AIMeiSheng/check_crash.py b/AIMeiSheng/check_crash.py
new file mode 100644
index 0000000..e69de29
diff --git a/AIMeiSheng/docker_demo/common.py b/AIMeiSheng/docker_demo/common.py
index 6d3d716..49b3628 100644
--- a/AIMeiSheng/docker_demo/common.py
+++ b/AIMeiSheng/docker_demo/common.py
@@ -1,128 +1,130 @@
import os
import sys
import time
# import logging
import urllib, urllib.request
# test / production environment
gs_prod = True
# if len(sys.argv) > 1 and sys.argv[1] == "prod":
# gs_prod = True
# print(gs_prod)
gs_tmp_dir = "/data/ai_meisheng_tmp"
gs_model_dir = "/data/ai_meisheng_models"
gs_resource_cache_dir = "/tmp/ai_meisheng_resource_cache"
gs_embed_model_path = os.path.join(gs_model_dir, "RawNet3/models/weights/model.pt")
gs_svc_model_path = os.path.join(gs_model_dir,
"weights/xusong_v2_org_version_alldata_embed_spkenx200x_double_e14_s90706.pth")
gs_hubert_model_path = os.path.join(gs_model_dir, "hubert.pt")
gs_rmvpe_model_path = os.path.join(gs_model_dir, "rmvpe.pt")
gs_embed_model_spk_path = os.path.join(gs_model_dir, "SpeakerEncoder/pretrained_model/best_model.pth.tar")
gs_embed_config_spk_path = os.path.join(gs_model_dir, "SpeakerEncoder/pretrained_model/config.json")
# errcode
gs_err_code_success = 0
gs_err_code_download_vocal = 100
gs_err_code_download_svc_url = 101
gs_err_code_svc_process = 102
gs_err_code_transcode = 103
gs_err_code_volume_adjust = 104
gs_err_code_upload = 105
gs_err_code_params = 106
gs_err_code_pending = 107
gs_err_code_target_silence = 108
gs_err_code_too_many_connections = 429
gs_err_code_gender_classify = 430
+gs_err_code_vocal_ratio = 431  # vocal ratio too low
+
gs_redis_conf = {
"host": "av-credis.starmaker.co",
"port": 6379,
"pwd": "lKoWEhz%jxTO",
}
# gs_server_redis_conf = {
# "producer": "dev_ai_meisheng_producer", # 输入的队列
# "ai_meisheng_key_prefix": "dev_ai_meisheng_key_", # 存储结果情况
# }
gs_server_redis_conf = {
"producer": "test_ai_meisheng_producer", # 输入的队列
"ai_meisheng_key_prefix": "test_ai_meisheng_key_", # 存储结果情况
}
if gs_prod:
gs_server_redis_conf = {
"producer": "ai_meisheng_producer", # 输入的队列
"ai_meisheng_key_prefix": "ai_meisheng_key_", # 存储结果情况
}
gs_feishu_conf = {
"url": "http://sg-prod-songbook-webmp-1:8000/api/feishu/people",
"users": [
"18810833785", # 杨建利
"17778007843", # 王健军
"18612496315", # 郭子豪
"18600542290" # 方兵晓
]
}
def download2disk(url, dst_path):
try:
urllib.request.urlretrieve(url, dst_path)
return os.path.exists(dst_path)
except Exception as ex:
print(f"download url={url} error", ex)
return False
def exec_cmd(cmd):
# gs_logger.info(cmd)
print(cmd)
ret = os.system(cmd)
if ret != 0:
return False
return True
def exec_cmd_and_result(cmd):
r = os.popen(cmd)
text = r.read()
r.close()
return text
def upload_file2cos(key, file_path, region='ap-singapore', bucket_name='av-audit-sync-sg-1256122840'):
"""
Upload a file to COS
:param key: destination key (path) in the bucket
:param file_path: local file path
:param region: COS region
:param bucket_name: bucket name
:return:
"""
gs_coscmd = "coscmd"
gs_coscmd_conf = "~/.cos.conf"
cmd = "{} -c {} -r {} -b {} upload {} {}".format(gs_coscmd, gs_coscmd_conf, region, bucket_name, file_path, key)
if exec_cmd(cmd):
cmd = "{} -c {} -r {} -b {} info {}".format(gs_coscmd, gs_coscmd_conf, region, bucket_name, key) \
+ "| grep Content-Length |awk \'{print $2}\'"
res_str = exec_cmd_and_result(cmd)
# logging.info("{},res={}".format(key, res_str))
size = float(res_str)
if size > 0:
return True
return False
return False
def check_input(input_data):
key_list = ["record_song_url", "target_url", "start", "end", "vocal_loudness", "female_recording_url",
"male_recording_url"]
for key in key_list:
if key not in input_data.keys():
return False
return True
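For reference, a minimal sketch of a task payload that satisfies check_input above: the field names come from key_list, while the values here are hypothetical placeholders.

# Sketch only: field names match check_input's key_list; values are hypothetical.
from common import check_input  # assumes AIMeiSheng/docker_demo is on sys.path

example_task = {
    "record_song_url": "https://example.com/recordings/input_vocal.m4a",
    "target_url": "https://example.com/results/output.m4a",
    "start": 0,                    # ms
    "end": 30000,                  # ms
    "vocal_loudness": -14.0,       # target loudness of the output
    "female_recording_url": "https://example.com/refs/female_ref.m4a",
    "male_recording_url": "https://example.com/refs/male_ref.m4a",
}
assert check_input(example_task)   # True: every required key is present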
diff --git a/AIMeiSheng/docker_demo/offline_server.py b/AIMeiSheng/docker_demo/offline_server.py
index f92592e..73f4c18 100644
--- a/AIMeiSheng/docker_demo/offline_server.py
+++ b/AIMeiSheng/docker_demo/offline_server.py
@@ -1,166 +1,170 @@
# -*- coding: UTF-8 -*-
"""
Offline processing: interact via redis — fetch task data from redis, then write the result back to redis
"""
import os
import sys
import time
import json
import socket
import hashlib
from redis_helper import RedisHelper
from cos_helper import CosHelper
from common import *
import logging
from feishu_helper import feishu_send
from svc_online import GSWorkerAttr, SVCOnline, volume_adjustment, svc_offline_logger
sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
def download_data(worker_attr):
if os.path.exists(worker_attr.vocal_path):
os.unlink(worker_attr.vocal_path)
st = time.time()
if not download2disk(worker_attr.vocal_url, worker_attr.vocal_path):
return gs_err_code_download_vocal
svc_offline_logger.info(f"download vocal_url={worker_attr.vocal_url} sp = {time.time() - st}")
# download svc_source_url
if not os.path.exists(worker_attr.female_svc_source_path):
st = time.time()
if not download2disk(worker_attr.female_svc_source_url, worker_attr.female_svc_source_path):
return gs_err_code_download_svc_url
svc_offline_logger.info(f"download female_url={worker_attr.female_svc_source_url} sp = {time.time() - st}")
# download svc_source_url
if not os.path.exists(worker_attr.male_svc_source_path):
st = time.time()
if not download2disk(worker_attr.male_svc_source_url, worker_attr.male_svc_source_path):
return gs_err_code_download_svc_url
svc_offline_logger.info(f"download male_url={worker_attr.male_svc_source_url} sp = {time.time() - st}")
return gs_err_code_success
def transcode(wav_path, dst_path):
st = time.time()
cmd = f"ffmpeg -i {wav_path} -ar 44100 -ac 1 -b:a 64k -y {dst_path} -loglevel fatal"
exec_cmd(cmd)
svc_offline_logger.info(f"transcode cmd={cmd}, sp = {time.time() - st}")
return os.path.exists(dst_path)
class OfflineServer:
def __init__(self, redis_conf, server_conf, update_redis=False):
st = time.time()
self.redis_helper = RedisHelper(redis_conf)
self.cos_helper = CosHelper()
self.svc_online = SVCOnline()
self.server_conf = server_conf
self.distinct_key = server_conf["ai_meisheng_key_prefix"]
self.update_redis = update_redis
svc_offline_logger.info(f"config={redis_conf}---server_conf={self.server_conf}")
svc_offline_logger.info(f"offline init finish sp={time.time() - st}")
def exists(self):
return self.redis_helper.exists(self.distinct_key)
def update_result(self, errcode, schedule, gender, target_song_url):
msg = {
"status": errcode,
"schedule": schedule,
"gender": gender,
"target_song_url": target_song_url,
}
# keep the result for 10 minutes (redis key expiry)
if self.update_redis:
self.redis_helper.set(self.distinct_key, json.dumps(msg))
self.redis_helper.expire(self.distinct_key, 60 * 10)
def process_one(self, worker_attr):
self.distinct_key = self.server_conf["ai_meisheng_key_prefix"] + worker_attr.distinct_id
svc_offline_logger.info(f"{worker_attr.log_info_name()}, start download ...")
err = download_data(worker_attr)
if err != gs_err_code_success:
self.update_result(err, 100, "unknown", worker_attr.target_url)
return err, None, None
self.update_result(err, 35, "unknown", worker_attr.target_url)
svc_offline_logger.info(f"{worker_attr.log_info_name()}, start process ...")
gender, err_code = self.svc_online.process(worker_attr)
+ if err_code == gs_err_code_gender_classify:
+ self.update_result(gs_err_code_gender_classify, 100, gender, worker_attr.target_url)
+ return gs_err_code_gender_classify, None, None
if err_code == gs_err_code_target_silence: # unvoice err
+ self.update_result(gs_err_code_target_silence, 100, gender, worker_attr.target_url)
return gs_err_code_target_silence, None, None
if not os.path.exists(worker_attr.target_wav_path):
self.update_result(gs_err_code_svc_process, 100, gender, worker_attr.target_url)
return gs_err_code_svc_process, None, None
self.update_result(err, 85, gender, worker_attr.target_url)
# stretch the volume to the target loudness
svc_offline_logger.info(f"{worker_attr.log_info_name()}, start volume_adjustment ...")
volume_adjustment(worker_attr.target_wav_path, worker_attr.target_loudness, worker_attr.target_wav_ad_path)
if not os.path.exists(worker_attr.target_wav_ad_path):
self.update_result(gs_err_code_volume_adjust, 100, gender, worker_attr.target_url)
return gs_err_code_volume_adjust, None, None
self.update_result(err, 90, gender, worker_attr.target_url)
# transcode
svc_offline_logger.info(f"{worker_attr.log_info_name()}, start transcode ...")
if not transcode(worker_attr.target_wav_ad_path, worker_attr.target_path):
self.update_result(gs_err_code_transcode, 100, gender, worker_attr.target_url)
return gs_err_code_transcode, None, None
self.update_result(err, 95, gender, worker_attr.target_url)
# upload
svc_offline_logger.info(f"{worker_attr.log_info_name()}, start upload_file2cos ...")
st = time.time()
if not self.cos_helper.upload_by_url(worker_attr.target_path, worker_attr.target_url):
self.update_result(gs_err_code_upload, 100, gender, worker_attr.target_url)
return gs_err_code_upload, None, None
self.update_result(gs_err_code_success, 100, gender, worker_attr.target_url)
svc_offline_logger.info(
f"{worker_attr.log_info_name()} upload {worker_attr.target_url} sp = {time.time() - st}")
return gs_err_code_success, worker_attr.target_url, gender
def start_signal(self):
"""
On startup, send a notification message to Feishu
:return:
"""
host_name = socket.gethostname()
host_ip = socket.gethostbyname(host_name)
msg = f"{host_name}({host_ip}) offline server start"
feishu_send(gs_feishu_conf["url"], msg, gs_feishu_conf["users"])
def process(self):
self.start_signal()
while True:
data = self.redis_helper.rpop(self.server_conf["producer"])
if data is None:
time.sleep(1)
continue
data = json.loads(data)
if not check_input(data):
svc_offline_logger.error(f"input data error={data}")
continue
worker_attr = GSWorkerAttr(data)
self.distinct_key = self.server_conf["ai_meisheng_key_prefix"] + worker_attr.distinct_id
if not self.exists():
svc_offline_logger.warning(f"input {data}, timeout abandon ....")
worker_attr.rm_cache()
continue
st = time.time()
errcode, target_path, gender = self.process_one(worker_attr)
self.update_result(errcode, 100, gender, target_path)
svc_offline_logger.info(f"{worker_attr.log_info_name()} finish errcode={errcode} sp = {time.time() - st}")
worker_attr.rm_cache()
if __name__ == '__main__':
offline_server = OfflineServer(gs_redis_conf, gs_server_redis_conf, True)
offline_server.process()
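For reference, a minimal sketch of the progress/result message that OfflineServer.update_result stores in redis under ai_meisheng_key_prefix + distinct_id; the values below are hypothetical.

# Sketch only: shape of the message written by OfflineServer.update_result.
# "status" is a gs_err_code_* value; "schedule" is progress in percent.
import json

result_msg = {
    "status": 0,             # gs_err_code_success
    "schedule": 100,         # advances through 35 / 85 / 90 / 95 / 100 (or 100 on failure)
    "gender": "female",      # "female", "male", or "unknown"
    "target_song_url": "https://example.com/results/output.m4a",
}
payload = json.dumps(result_msg)   # stored via redis SET; the key expires after 10 minutes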
diff --git a/AIMeiSheng/docker_demo/svc_online.py b/AIMeiSheng/docker_demo/svc_online.py
index 1606703..421b5f6 100644
--- a/AIMeiSheng/docker_demo/svc_online.py
+++ b/AIMeiSheng/docker_demo/svc_online.py
@@ -1,197 +1,207 @@
# -*- coding: UTF-8 -*-
"""
Core processing logic of the SVC service
"""
import os
import time
import socket
import shutil
import hashlib
from AIMeiSheng.meisheng_svc_final import load_model, process_svc_online
from AIMeiSheng.cos_similar_ui_zoom import cos_similar
from AIMeiSheng.meisheng_env_preparex import meisheng_env_prepare
from AIMeiSheng.voice_classification.online.voice_class_online_fang import VoiceClass, download_volume_balanced
from AIMeiSheng.docker_demo.common import *
import logging
hostname = socket.gethostname()
log_file_name = f"{os.path.dirname(os.path.abspath(__file__))}/av_meisheng_{hostname}.log"
# set up the logger
svc_offline_logger = logging.getLogger("svc_offline")
file_handler = logging.FileHandler(log_file_name)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S')
file_handler.setFormatter(formatter)
# if gs_prod:
# svc_offline_logger.addHandler(file_handler)
if os.path.exists(gs_tmp_dir):
shutil.rmtree(gs_tmp_dir)
os.makedirs(gs_model_dir, exist_ok=True)
os.makedirs(gs_resource_cache_dir, exist_ok=True)
# preset parameters
gs_gender_models_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/hub/voice_classification/models.zip"
gs_volume_bin_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/dataset/AIMeiSheng/ebur128_tool/v1/ebur128_tool"
class GSWorkerAttr:
def __init__(self, input_data):
# extract the input resources
vocal_url = input_data["record_song_url"]
target_url = input_data["target_url"]
start = input_data["start"] # 单位是ms
end = input_data["end"] # 单位是ms
vocal_loudness = input_data["vocal_loudness"]
female_recording_url = input_data["female_recording_url"]
male_recording_url = input_data["male_recording_url"]
self.distinct_id = hashlib.md5(vocal_url.encode()).hexdigest()
self.tmp_dir = os.path.join(gs_tmp_dir, self.distinct_id)
if os.path.exists(self.tmp_dir):
shutil.rmtree(self.tmp_dir)
os.makedirs(self.tmp_dir)
self.vocal_url = vocal_url
self.target_url = target_url
ext = vocal_url.split(".")[-1]
self.vocal_path = os.path.join(self.tmp_dir, self.distinct_id + f"_in.{ext}")
self.target_wav_path = os.path.join(self.tmp_dir, self.distinct_id + "_out.wav")
self.target_wav_ad_path = os.path.join(self.tmp_dir, self.distinct_id + "_out_ad.wav")
self.target_path = os.path.join(self.tmp_dir, self.distinct_id + "_out.m4a")
self.female_svc_source_url = female_recording_url
self.male_svc_source_url = male_recording_url
ext = female_recording_url.split(".")[-1]
self.female_svc_source_path = os.path.join(self.tmp_dir, self.distinct_id + f"_female.{ext}")
ext = male_recording_url.split(".")[-1]
self.male_svc_source_path = os.path.join(self.tmp_dir, self.distinct_id + f"_male.{ext}")
# self.female_svc_source_path = os.path.join(gs_resource_cache_dir,
# hashlib.md5(female_recording_url.encode()).hexdigest() + "." + ext)
# ext = male_recording_url.split(".")[-1]
# self.male_svc_source_path = os.path.join(gs_resource_cache_dir,
# hashlib.md5(male_recording_url.encode()).hexdigest() + "." + ext)
self.st_tm = start
self.ed_tm = end
self.target_loudness = vocal_loudness
def log_info_name(self):
return f"d_id={self.distinct_id}, vocal_url={self.vocal_url}"
def rm_cache(self):
if os.path.exists(self.tmp_dir):
shutil.rmtree(self.tmp_dir)
def init_gender_model():
"""
Download the gender classification models
:return:
"""
dst_model_dir = os.path.join(gs_model_dir, "voice_classification")
if not os.path.exists(dst_model_dir):
dst_zip_path = os.path.join(gs_model_dir, "models.zip")
if not download2disk(gs_gender_models_url, dst_zip_path):
svc_offline_logger.fatal(f"download gender_model err={gs_gender_models_url}")
cmd = f"cd {gs_model_dir}; unzip {dst_zip_path}; mv models voice_classification; rm -f {dst_zip_path}"
os.system(cmd)
if not os.path.exists(dst_model_dir):
svc_offline_logger.fatal(f"unzip {dst_zip_path} err")
music_voice_pure_model = os.path.join(dst_model_dir, "voice_005_rec_v5.pth")
music_voice_no_pure_model = os.path.join(dst_model_dir, "voice_10_v5.pth")
gender_pure_model = os.path.join(dst_model_dir, "gender_8k_ratev5_v6_adam.pth")
gender_no_pure_model = os.path.join(dst_model_dir, "gender_8k_v6_adam.pth")
vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model)
return vc
def init_svc_model():
meisheng_env_prepare(logging, gs_model_dir)
embed_model, hubert_model = load_model()
cs_sim = cos_similar()
return embed_model, hubert_model, cs_sim
def download_volume_adjustment():
"""
Download the volume adjustment tool
:return:
"""
volume_bin_path = os.path.join(gs_model_dir, "ebur128_tool")
if not os.path.exists(volume_bin_path):
if not download2disk(gs_volume_bin_url, volume_bin_path):
svc_offline_logger.fatal(f"download volume_bin err={gs_volume_bin_url}")
os.system(f"chmod +x {volume_bin_path}")
def volume_adjustment(wav_path, target_loudness, out_path):
"""
Adjust the volume to the target loudness
:param wav_path:
:param target_loudness:
:param out_path:
:return:
"""
volume_bin_path = os.path.join(gs_model_dir, "ebur128_tool")
cmd = f"{volume_bin_path} {wav_path} {target_loudness} {out_path}"
os.system(cmd)
class SVCOnline:
def __init__(self):
st = time.time()
self.gender_model = init_gender_model()
self.embed_model, self.hubert_model, self.cs_sim = init_svc_model()
download_volume_adjustment()
download_volume_balanced()
svc_offline_logger.info(f"svc init finished, sp = {time.time() - st}")
def gender_process(self, worker_attr):
st = time.time()
gender, female_rate, is_pure = self.gender_model.process(worker_attr.vocal_path)
+
svc_offline_logger.info(
f"{worker_attr.vocal_url}, gender={gender}, female_rate={female_rate}, is_pure={is_pure}, "
f"gender_process sp = {time.time() - st}")
if gender == 0:
gender = 'female'
elif gender == 1:
gender = 'male'
elif female_rate == None:
gender = 'male'
return gender, gs_err_code_gender_classify
elif female_rate > 0.5:
gender = 'female'
else:
gender = 'male'
+ if gender == 'female':
+ if self.gender_model.vocal_ratio < 0.5:
+ print(f"@@@ vocal_ratio: {self.gender_model.vocal_ratio}, gender : {gender}, gs_err_code_vocal_ratio : {gs_err_code_vocal_ratio}")
+ return gender, gs_err_code_vocal_ratio
+ else:
+ if self.gender_model.vocal_ratio < 0.6:
+ print(f"@@@ vocal_ratio: {self.gender_model.vocal_ratio}, gender : {gender}, gs_err_code_vocal_ratio : {gs_err_code_vocal_ratio}")
+ return gender, gs_err_code_vocal_ratio
+
svc_offline_logger.info(f"{worker_attr.vocal_url}, modified gender={gender}")
# err = gs_err_code_success
# if female_rate == -1:
# err = gs_err_code_target_silence
return gender, gs_err_code_success
def process(self, worker_attr):
gender, err = self.gender_process(worker_attr)
if err != gs_err_code_success:
return gender, err
song_path = worker_attr.female_svc_source_path
if gender == "male":
song_path = worker_attr.male_svc_source_path
params = {'gender': gender, 'tst': worker_attr.st_tm, "tnd": worker_attr.ed_tm, 'delay': 0, 'song_path': None}
st = time.time()
err_code = process_svc_online(song_path, worker_attr.vocal_path, worker_attr.target_wav_path, self.embed_model,
self.hubert_model, self.cs_sim, params)
svc_offline_logger.info(f"{worker_attr.vocal_url}, err_code={err_code} process svc sp = {time.time() - st}")
return gender, err_code
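The new vocal-ratio gate in gender_process can be summarized with the small helper below; it is a standalone restatement of the thresholds above (0.5 for recordings classified as female, 0.6 otherwise), not part of the service code.

# Sketch only: restates the vocal-ratio thresholds used in SVCOnline.gender_process.
def vocal_ratio_ok(gender: str, vocal_ratio: float) -> bool:
    """Return True if the vocal ratio is high enough to continue processing."""
    threshold = 0.5 if gender == "female" else 0.6
    return vocal_ratio >= threshold

# e.g. vocal_ratio_ok("female", 0.45) -> False  => gs_err_code_vocal_ratio (431)
# e.g. vocal_ratio_ok("male", 0.55)   -> False  => gs_err_code_vocal_ratio (431)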
diff --git a/AIMeiSheng/voice_classification/online/voice_class_online_fang.py b/AIMeiSheng/voice_classification/online/voice_class_online_fang.py
index d792db0..a74173d 100644
--- a/AIMeiSheng/voice_classification/online/voice_class_online_fang.py
+++ b/AIMeiSheng/voice_classification/online/voice_class_online_fang.py
@@ -1,444 +1,446 @@
"""
Online male/female voice classification tool
1 Transcode to 16-bit mono
2 Volume equalization
3 Model classification
"""
import os
import sys
import librosa
import shutil
import logging
import time
import torch.nn.functional as F
import numpy as np
from model import *
# from common import bind_kernel
logging.basicConfig(level=logging.INFO)
os.environ["LRU_CACHE_CAPACITY"] = "1"
# torch.set_num_threads(1)
# bind_kernel(1)
"""
Temporary global variables, used for timing statistics
"""
transcode_time = 0
vb_time = 0
mfcc_time = 0
predict_time = 0
"""
Error codes
"""
ERR_CODE_SUCCESS = 0 # success
ERR_CODE_NO_FILE = -1 # file does not exist
ERR_CODE_TRANSCODE = -2 # transcode failed
ERR_CODE_VOLUME_BALANCED = -3 # volume balancing failed
ERR_CODE_FEATURE_TOO_SHORT = -4 # feature sequence too short
"""
Constants
"""
FRAME_LEN = 128
MFCC_LEN = 80
gs_bin_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/hub/voice_classification/bin/bin.zip"
EBUR128_BIN = "/tmp/voice_class_bin/standard_audio_no_cut"
# EBUR128_BIN = "/Users/yangjianli/linux/opt/soft/bin/standard_audio_no_cut"
GENDER_FEMALE = 0
GENDER_MALE = 1
GENDER_OTHER = 2
"""
Generic helper functions
"""
def exec_cmd(cmd):
ret = os.system(cmd)
if ret != 0:
return False
return True
"""
Business-logic functions
"""
def get_one_mfcc(file_url):
st = time.time()
data, sr = librosa.load(file_url, sr=16000)
if len(data) < 512:
return []
mfcc = librosa.feature.mfcc(y=data, sr=sr, n_fft=512, hop_length=256, n_mfcc=MFCC_LEN)
mfcc = mfcc.transpose()
print("get_one_mfcc:spend_time={}".format(time.time() - st))
global mfcc_time
mfcc_time += time.time() - st
return mfcc
def download_volume_balanced():
import urllib.request
if not os.path.exists(EBUR128_BIN):
dst_path = "/tmp/bin.zip"
urllib.request.urlretrieve(gs_bin_url, dst_path)
if not os.path.exists(dst_path):
print(f"download dst_path={gs_bin_url} err!")
exit(-1)
dirname = os.path.dirname(dst_path)
cmd = f"cd {dirname}; unzip bin.zip; rm -f bin.zip; mv bin voice_class_bin"
os.system(cmd)
if not os.path.exists(EBUR128_BIN):
print(f"exec {cmd} err!")
exit(-1)
def volume_balanced(src, dst):
st = time.time()
download_volume_balanced()
cmd = "{} {} {}".format(EBUR128_BIN, src, dst)
logging.info(cmd)
exec_cmd(cmd)
if not os.path.exists(dst):
logging.error("volume_balanced:cmd={}".format(cmd))
print("volume_balanced:spend_time={}".format(time.time() - st))
global vb_time
vb_time += time.time() - st
return os.path.exists(dst)
def transcode(src, dst):
st = time.time()
cmd = "ffmpeg -loglevel quiet -i {} -ar 16000 -ac 1 {}".format(src, dst)
logging.info(cmd)
exec_cmd(cmd)
if not os.path.exists(dst):
logging.error("transcode:cmd={}".format(cmd))
print("transcode:spend_time={}".format(time.time() - st))
global transcode_time
transcode_time += time.time() - st
return os.path.exists(dst)
class VoiceClass:
def __init__(self, music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model):
"""
Four models
:param music_voice_pure_model: distinguishes pure vocals vs. other
:param music_voice_no_pure_model: distinguishes vocals-present vs. other
:param gender_pure_model: male/female classification on pure vocals
:param gender_no_pure_model: male/female classification on vocals-present audio
"""
st = time.time()
self.device = "cpu"
self.batch_size = 256
self.music_voice_pure_model = load_model(MusicVoiceV5Model, music_voice_pure_model, self.device)
self.music_voice_no_pure_model = load_model(MusicVoiceV5Model, music_voice_no_pure_model, self.device)
self.gender_pure_model = load_model(MobileNetV2Gender, gender_pure_model, self.device)
self.gender_no_pure_model = load_model(MobileNetV2Gender, gender_no_pure_model, self.device)
logging.info("load model ok ! spend_time={}".format(time.time() - st))
+ self.vocal_ratio = 1
def batch_predict(self, model, features):
st = time.time()
scores = []
with torch.no_grad():
for i in range(0, len(features), self.batch_size):
cur_data = features[i:i + self.batch_size].to(self.device)
predicts = model(cur_data)
predicts_score = F.softmax(predicts, dim=1)
scores.extend(predicts_score.cpu().numpy())
ret = np.array(scores)
global predict_time
predict_time += time.time() - st
return ret
def predict_pure(self, filename, features):
scores = self.batch_predict(self.music_voice_pure_model, features)
new_features = []
for idx, score in enumerate(scores):
if score[0] > 0.5: # non-vocal segment
continue
new_features.append(features[idx].numpy())
# too few vocal segments to process
# these thresholds can be tuned
new_feature_len = len(new_features)
new_feature_rate = len(new_features) / len(features)
+ self.vocal_ratio = new_feature_rate
# threshold adjusted to allow a lower vocal ratio
if new_feature_len < 4 or new_feature_rate < 0.05:
logging.warning(
"filename={}|predict_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate)
)
return GENDER_OTHER, -1
new_features = torch.from_numpy(np.array(new_features))
scores = self.batch_predict(self.gender_pure_model, new_features)
f_avg = sum(scores[:, 0]) / len(scores)
m_avg = sum(scores[:, 1]) / len(scores)
female_rate = f_avg / (f_avg + m_avg)
if female_rate > 0.65:
return GENDER_FEMALE, female_rate
if female_rate < 0.12:
return GENDER_MALE, female_rate
logging.warning(
"filename={}|predict_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate)
)
return GENDER_OTHER, female_rate
def predict_no_pure(self, filename, features):
scores = self.batch_predict(self.music_voice_no_pure_model, features)
new_features = []
for idx, score in enumerate(scores):
if score[0] > 0.5: # non-vocal segment
continue
new_features.append(features[idx].numpy())
# too few vocal segments to process
# these thresholds can be tuned
new_feature_len = len(new_features)
new_feature_rate = len(new_features) / len(features)
# threshold adjusted to allow a lower vocal ratio
if new_feature_len < 4 or new_feature_rate < 0.05:
logging.warning(
"filename={}|predict_no_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate)
)
return GENDER_OTHER, -1
new_features = torch.from_numpy(np.array(new_features))
scores = self.batch_predict(self.gender_no_pure_model, new_features)
f_avg = sum(scores[:, 0]) / len(scores)
m_avg = sum(scores[:, 1]) / len(scores)
female_rate = f_avg / (f_avg + m_avg)
if female_rate > 0.75:
return GENDER_FEMALE, female_rate
if female_rate < 0.1:
return GENDER_MALE, female_rate
logging.warning(
"filename={}|predict_no_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate)
)
return GENDER_OTHER, female_rate
def predict(self, filename, features):
st = time.time()
new_features = []
for i in range(FRAME_LEN, len(features), FRAME_LEN):
new_features.append(features[i - FRAME_LEN: i])
new_features = torch.from_numpy(np.array(new_features))
gender, rate = self.predict_pure(filename, new_features)
if gender == GENDER_OTHER:
logging.info("start no pure process...")
gender, rate = self.predict_no_pure(filename, new_features)
return gender, rate, False
print("predict|spend_time={}".format(time.time() - st))
return gender, rate, True
def process_one_logic(self, filename, file_path, cache_dir):
tmp_wav = os.path.join(cache_dir, "tmp.wav")
tmp_vb_wav = os.path.join(cache_dir, "tmp_vb.wav")
if not transcode(file_path, tmp_wav):
return ERR_CODE_TRANSCODE, None, None
if not volume_balanced(tmp_wav, tmp_vb_wav):
return ERR_CODE_VOLUME_BALANCED, None, None
features = get_one_mfcc(tmp_vb_wav)
if len(features) < FRAME_LEN:
logging.error("feature too short|file_path={}".format(file_path))
return ERR_CODE_FEATURE_TOO_SHORT, None, None
return self.predict(filename, features)
def process_one(self, file_path):
base_dir = os.path.dirname(file_path)
filename = os.path.splitext(file_path)[0]
print("filename:", filename)
cache_dir = os.path.join(base_dir, filename + "_cache")
if os.path.exists(cache_dir):
shutil.rmtree(cache_dir)
os.makedirs(cache_dir)
ret = self.process_one_logic(filename, file_path, cache_dir)
shutil.rmtree(cache_dir)
return ret
def process(self, file_path):
gender, female_rate, is_pure = self.process_one(file_path)
logging.info("{}|gender={}|female_rate={}".format(file_path, gender, female_rate))
return gender, female_rate, is_pure
def process_by_feature(self, feature_file):
"""
Process a feature file directly
:param feature_file:
:return:
"""
filename = os.path.splitext(feature_file)[0]
features = np.load(feature_file)
gender, female_rate, _ = self.predict(filename, features)  # predict returns (gender, rate, is_pure)
return gender, female_rate
def test_all_feature():
import glob
base_dir = "/data/datasets/music_voice_dataset_full/feature_online_data_v3"
female = glob.glob(os.path.join(base_dir, "female/*feature.npy"))
male = glob.glob(os.path.join(base_dir, "male/*feature.npy"))
other = glob.glob(os.path.join(base_dir, "other/*feature.npy"))
model_path = "/data/jianli.yang/voice_classification/online/models"
music_voice_pure_model = os.path.join(model_path, "voice_005_rec_v5.pth")
music_voice_no_pure_model = os.path.join(model_path, "voice_10_v5.pth")
gender_pure_model = os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth")
gender_no_pure_model = os.path.join(model_path, "gender_8k_v6_adam.pth")
vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model)
tot_st = time.time()
ret_map = {
0: {0: 0, 1: 0, 2: 0},
1: {0: 0, 1: 0, 2: 0},
2: {0: 0, 1: 0, 2: 0}
}
for file in female:
st = time.time()
print("------------------------------>>>>>")
gender, female_score = vc.process_by_feature(file)
ret_map[0][gender] += 1
if gender != 0:
print("err:female->{}|{}|{}".format(gender, file, female_score))
print("process|spend_tm=={}".format(time.time() - st))
for file in male:
st = time.time()
print("------------------------------>>>>>")
gender, female_score = vc.process_by_feature(file)
ret_map[1][gender] += 1
if gender != 1:
print("err:male->{}|{}|{}".format(gender, file, female_score))
print("process|spend_tm=={}".format(time.time() - st))
for file in other:
st = time.time()
print("------------------------------>>>>>")
gender, female_score = vc.process_by_feature(file)
ret_map[2][gender] += 1
if gender != 2:
print("err:other->{}|{}|{}".format(gender, file, female_score))
print("process|spend_tm=={}".format(time.time() - st))
global transcode_time, vb_time, mfcc_time, predict_time
print("spend_time:tot={}|transcode={}|vb={}|gen_feature={}|predict={}".format(time.time() - tot_st, transcode_time,
vb_time, mfcc_time, predict_time))
f_f = ret_map[0][0]
f_m = ret_map[0][1]
f_o = ret_map[0][2]
m_f = ret_map[1][0]
m_m = ret_map[1][1]
m_o = ret_map[1][2]
o_f = ret_map[2][0]
o_m = ret_map[2][1]
o_o = ret_map[2][2]
print("ff:{},fm:{},fo:{}".format(f_f, f_m, f_o))
print("mm:{},mf:{},mo:{}".format(m_m, m_f, m_o))
print("om:{},of:{},oo:{}".format(o_m, o_f, o_o))
# female precision and recall
f_acc = f_f / (f_f + m_f + o_f)
f_recall = f_f / (f_f + f_m + f_o)
# male precision and recall
m_acc = m_m / (m_m + f_m + o_m)
m_recall = m_m / (m_m + m_f + m_o)
print("female: acc={}|recall={}".format(f_acc, f_recall))
print("male: acc={}|recall={}".format(m_acc, m_recall))
def test_all():
import glob
base_dir = "/data/datasets/music_voice_dataset_full/online_data_v3_top200"
female = glob.glob(os.path.join(base_dir, "female/*mp4"))
male = glob.glob(os.path.join(base_dir, "male/*mp4"))
other = glob.glob(os.path.join(base_dir, "other/*mp4"))
model_path = "/data/jianli.yang/voice_classification/online/models"
music_voice_pure_model = os.path.join(model_path, "voice_005_rec_v5.pth")
music_voice_no_pure_model = os.path.join(model_path, "voice_10_v5.pth")
gender_pure_model = os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth")
gender_no_pure_model = os.path.join(model_path, "gender_8k_v6_adam.pth")
vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model)
tot_st = time.time()
ret_map = {
0: {0: 0, 1: 0, 2: 0},
1: {0: 0, 1: 0, 2: 0},
2: {0: 0, 1: 0, 2: 0}
}
for file in female:
st = time.time()
print("------------------------------>>>>>")
gender, female_score = vc.process(file)
ret_map[0][gender] += 1
if gender != 0:
print("err:female->{}|{}|{}".format(gender, file, female_score))
print("process|spend_tm=={}".format(time.time() - st))
for file in male:
st = time.time()
print("------------------------------>>>>>")
gender, female_score = vc.process(file)
ret_map[1][gender] += 1
if gender != 1:
print("err:male->{}|{}|{}".format(gender, file, female_score))
print("process|spend_tm=={}".format(time.time() - st))
for file in other:
st = time.time()
print("------------------------------>>>>>")
gender, female_score = vc.process(file)
ret_map[2][gender] += 1
if gender != 2:
print("err:other->{}|{}|{}".format(gender, file, female_score))
print("process|spend_tm=={}".format(time.time() - st))
global transcode_time, vb_time, mfcc_time, predict_time
print("spend_time:tot={}|transcode={}|vb={}|gen_feature={}|predict={}".format(time.time() - tot_st, transcode_time,
vb_time, mfcc_time, predict_time))
f_f = ret_map[0][0]
f_m = ret_map[0][1]
f_o = ret_map[0][2]
m_f = ret_map[1][0]
m_m = ret_map[1][1]
m_o = ret_map[1][2]
o_f = ret_map[2][0]
o_m = ret_map[2][1]
o_o = ret_map[2][2]
print("ff:{},fm:{},fo:{}".format(f_f, f_m, f_o))
print("mm:{},mf:{},mo:{}".format(m_m, m_f, m_o))
print("om:{},of:{},oo:{}".format(o_m, o_f, o_o))
# female precision and recall
f_acc = f_f / (f_f + m_f + o_f)
f_recall = f_f / (f_f + f_m + f_o)
# male precision and recall
m_acc = m_m / (m_m + f_m + o_m)
m_recall = m_m / (m_m + m_f + m_o)
print("female: acc={}|recall={}".format(f_acc, f_recall))
print("male: acc={}|recall={}".format(m_acc, m_recall))
if __name__ == "__main__":
# test_all()
# test_all_feature()
model_path = sys.argv[1]
voice_path = sys.argv[2]
music_voice_pure_model = os.path.join(model_path, "voice_005_rec_v5.pth")
music_voice_no_pure_model = os.path.join(model_path, "voice_10_v5.pth")
gender_pure_model = os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth")
gender_no_pure_model = os.path.join(model_path, "gender_8k_v6_adam.pth")
vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model)
for i in range(0, 1):
st = time.time()
print("------------------------------>>>>>")
gender, female_rate, is_pure = vc.process(voice_path)
print("process|spend_tm=={}".format(time.time() - st))
print("gender:{}, female_rate:{},is_pure:{}".format(gender, female_rate, is_pure))