Page MenuHomePhabricator

No OneTemporary

diff --git a/AIMeiSheng/check_crash.py b/AIMeiSheng/check_crash.py
new file mode 100644
index 0000000..e69de29
diff --git a/AIMeiSheng/docker_demo/common.py b/AIMeiSheng/docker_demo/common.py
index 6d3d716..49b3628 100644
--- a/AIMeiSheng/docker_demo/common.py
+++ b/AIMeiSheng/docker_demo/common.py
@@ -1,128 +1,130 @@
import os
import sys
import time
# import logging
import urllib, urllib.request
# Test / production environment switch (True selects the production queue and key names below).
gs_prod = True
# if len(sys.argv) > 1 and sys.argv[1] == "prod":
# gs_prod = True
# print(gs_prod)
# Local working directories: per-job scratch files, model storage, resource cache.
gs_tmp_dir = "/data/ai_meisheng_tmp"
gs_model_dir = "/data/ai_meisheng_models"
gs_resource_cache_dir = "/tmp/ai_meisheng_resource_cache"
# Model file locations, all under gs_model_dir.
gs_embed_model_path = os.path.join(gs_model_dir, "RawNet3/models/weights/model.pt")
gs_svc_model_path = os.path.join(gs_model_dir,
"weights/xusong_v2_org_version_alldata_embed_spkenx200x_double_e14_s90706.pth")
gs_hubert_model_path = os.path.join(gs_model_dir, "hubert.pt")
gs_rmvpe_model_path = os.path.join(gs_model_dir, "rmvpe.pt")
gs_embed_model_spk_path = os.path.join(gs_model_dir, "SpeakerEncoder/pretrained_model/best_model.pth.tar")
gs_embed_config_spk_path = os.path.join(gs_model_dir, "SpeakerEncoder/pretrained_model/config.json")
# Error codes; the offline server reports them as the "status" field of the redis result message.
gs_err_code_success = 0
gs_err_code_download_vocal = 100
gs_err_code_download_svc_url = 101
gs_err_code_svc_process = 102
gs_err_code_transcode = 103
gs_err_code_volume_adjust = 104
gs_err_code_upload = 105
gs_err_code_params = 106
gs_err_code_pending = 107
gs_err_code_target_silence = 108
gs_err_code_too_many_connections = 429
gs_err_code_gender_classify = 430
+gs_err_code_vocal_ratio = 431 # vocal ratio (share of the recording containing voice) is below the threshold
+
# NOTE(review): the redis password is committed in source; consider loading it from the environment.
gs_redis_conf = {
"host": "av-credis.starmaker.co",
"port": 6379,
"pwd": "lKoWEhz%jxTO",
}
# gs_server_redis_conf = {
# "producer": "dev_ai_meisheng_producer", # input queue
# "ai_meisheng_key_prefix": "dev_ai_meisheng_key_", # key prefix under which results are stored
# }
# Queue / result-key names; the test names are replaced by the production ones when gs_prod is True.
gs_server_redis_conf = {
"producer": "test_ai_meisheng_producer", # input queue
"ai_meisheng_key_prefix": "test_ai_meisheng_key_", # key prefix under which results are stored
}
if gs_prod:
gs_server_redis_conf = {
"producer": "ai_meisheng_producer", # input queue
"ai_meisheng_key_prefix": "ai_meisheng_key_", # key prefix under which results are stored
}
# Feishu alert endpoint and recipients (presumably phone numbers identifying the users — confirm).
gs_feishu_conf = {
"url": "http://sg-prod-songbook-webmp-1:8000/api/feishu/people",
"users": [
"18810833785", # Yang Jianli
"17778007843", # Wang Jianjun
"18612496315", # Guo Zihao
"18600542290" # Fang Bingxiao
]
}
def download2disk(url, dst_path):
    """Fetch ``url`` and store it at ``dst_path``.

    Returns True when ``dst_path`` exists after the download. Any exception is
    printed together with the URL, and False is returned.
    """
    try:
        urllib.request.urlretrieve(url, dst_path)
        downloaded = os.path.exists(dst_path)
    except Exception as err:
        print(f"download url={url} error", err)
        downloaded = False
    return downloaded
def exec_cmd(cmd):
    """Echo ``cmd`` to stdout, run it through the shell and report success.

    Returns True only when the shell exit status is 0.
    """
    # gs_logger.info(cmd)
    print(cmd)
    return os.system(cmd) == 0
def exec_cmd_and_result(cmd):
    """Run ``cmd`` through the shell and return everything it printed to stdout."""
    # The pipe is closed by the context manager, just like an explicit close().
    with os.popen(cmd) as pipe:
        return pipe.read()
def upload_file2cos(key, file_path, region='ap-singapore', bucket_name='av-audit-sync-sg-1256122840'):
    """
    Upload a local file to COS with the ``coscmd`` CLI, then verify it by size.

    :param key: object key (path) inside the bucket
    :param file_path: local file to upload
    :param region: COS region
    :param bucket_name: bucket name
    :return: True when the upload command succeeds and the object's reported
             Content-Length is positive. Otherwise False, including when the
             size cannot be read back.
    """
    gs_coscmd = "coscmd"
    gs_coscmd_conf = "~/.cos.conf"
    cmd = "{} -c {} -r {} -b {} upload {} {}".format(gs_coscmd, gs_coscmd_conf, region, bucket_name, file_path, key)
    if not exec_cmd(cmd):
        return False
    cmd = "{} -c {} -r {} -b {} info {}".format(gs_coscmd, gs_coscmd_conf, region, bucket_name, key) \
        + "| grep Content-Length |awk \'{print $2}\'"
    res_str = exec_cmd_and_result(cmd)
    # logging.info("{},res={}".format(key, res_str))
    try:
        size = float(res_str)
    except ValueError:
        # Empty or unexpected `info` output (e.g. no Content-Length line) used to
        # raise here; report it as a failed upload instead.
        return False
    return size > 0
def check_input(input_data):
    """Return True only if every field a job needs is present in ``input_data``."""
    required_fields = (
        "record_song_url", "target_url", "start", "end", "vocal_loudness",
        "female_recording_url", "male_recording_url",
    )
    return all(field in input_data.keys() for field in required_fields)
diff --git a/AIMeiSheng/docker_demo/offline_server.py b/AIMeiSheng/docker_demo/offline_server.py
index f92592e..73f4c18 100644
--- a/AIMeiSheng/docker_demo/offline_server.py
+++ b/AIMeiSheng/docker_demo/offline_server.py
@@ -1,166 +1,170 @@
# -*- coding: UTF-8 -*-
"""
离线处理: 使用redis进行交互,从redis中获取数据资源,在将结果写入到redis
"""
import os
import sys
import time
import json
import socket
import hashlib
from redis_helper import RedisHelper
from cos_helper import CosHelper
from common import *
import logging
from feishu_helper import feishu_send
from svc_online import GSWorkerAttr, SVCOnline, volume_adjustment, svc_offline_logger
sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
def download_data(worker_attr):
    """Download the user's vocal and both reference recordings for one job.

    The vocal file is always re-fetched (a stale copy is removed first). The
    female/male reference recordings are fetched only when not already on disk.
    Returns gs_err_code_success, gs_err_code_download_vocal when the vocal
    download fails, or gs_err_code_download_svc_url when a reference download
    fails.
    """
    vocal_path = worker_attr.vocal_path
    if os.path.exists(vocal_path):
        os.unlink(vocal_path)
    started = time.time()
    if not download2disk(worker_attr.vocal_url, vocal_path):
        return gs_err_code_download_vocal
    svc_offline_logger.info(f"download vocal_url={worker_attr.vocal_url} sp = {time.time() - started}")

    # (log label, source url, local path) for each svc reference recording.
    references = (
        ("female_url", worker_attr.female_svc_source_url, worker_attr.female_svc_source_path),
        ("male_url", worker_attr.male_svc_source_url, worker_attr.male_svc_source_path),
    )
    for label, url, path in references:
        if os.path.exists(path):
            continue
        started = time.time()
        if not download2disk(url, path):
            return gs_err_code_download_svc_url
        svc_offline_logger.info(f"download {label}={url} sp = {time.time() - started}")
    return gs_err_code_success
def transcode(wav_path, dst_path):
    """Re-encode ``wav_path`` into ``dst_path`` (44.1 kHz, mono, 64 kbit/s) with ffmpeg.

    Returns whether ``dst_path`` exists once ffmpeg has finished.
    """
    t0 = time.time()
    ffmpeg_cmd = "ffmpeg -i {} -ar 44100 -ac 1 -b:a 64k -y {} -loglevel fatal".format(wav_path, dst_path)
    exec_cmd(ffmpeg_cmd)
    svc_offline_logger.info("transcode cmd={}, sp = {}".format(ffmpeg_cmd, time.time() - t0))
    return os.path.exists(dst_path)
class OfflineServer:
    """
    Offline worker: pops jobs from a redis queue and runs the SVC pipeline
    (download -> convert -> loudness adjust -> transcode -> upload). Progress
    and results are published back to redis.
    """

    def __init__(self, redis_conf, server_conf, update_redis=False):
        """
        :param redis_conf: connection settings handed to RedisHelper
        :param server_conf: dict with the input queue name ("producer") and the
                            result key prefix ("ai_meisheng_key_prefix")
        :param update_redis: when False, update_result() does not write to redis
        """
        st = time.time()
        self.redis_helper = RedisHelper(redis_conf)
        self.cos_helper = CosHelper()
        self.svc_online = SVCOnline()
        self.server_conf = server_conf
        self.distinct_key = server_conf["ai_meisheng_key_prefix"]
        self.update_redis = update_redis
        svc_offline_logger.info(f"config={redis_conf}---server_conf={self.server_conf}")
        svc_offline_logger.info(f"offline init finish sp={time.time() - st}")

    def exists(self):
        """Return whether the current job's result key is present in redis."""
        return self.redis_helper.exists(self.distinct_key)

    def update_result(self, errcode, schedule, gender, target_song_url):
        """
        Write the job state under ``self.distinct_key`` (only when update_redis is set).

        :param errcode: status code (one of the gs_err_code_* values)
        :param schedule: progress percentage; 100 means the job has finished
        :param gender: detected gender, or "unknown"/None when not available
        :param target_song_url: URL of the converted song
        """
        msg = {
            "status": errcode,
            "schedule": schedule,
            "gender": gender,
            "target_song_url": target_song_url,
        }
        # The result expires after 10 minutes.
        if self.update_redis:
            self.redis_helper.set(self.distinct_key, json.dumps(msg))
            self.redis_helper.expire(self.distinct_key, 60 * 10)

    def process_one(self, worker_attr):
        """
        Run the whole pipeline for one job, publishing progress along the way.

        :return: (gs_err_code_success, target_url, gender) on success,
                 (errcode, None, None) on failure
        """
        self.distinct_key = self.server_conf["ai_meisheng_key_prefix"] + worker_attr.distinct_id
        svc_offline_logger.info(f"{worker_attr.log_info_name()}, start download ...")
        err = download_data(worker_attr)
        if err != gs_err_code_success:
            self.update_result(err, 100, "unknown", worker_attr.target_url)
            return err, None, None
        self.update_result(err, 35, "unknown", worker_attr.target_url)

        svc_offline_logger.info(f"{worker_attr.log_info_name()}, start process ...")
        gender, err_code = self.svc_online.process(worker_attr)
        # Errors reported explicitly by SVCOnline are forwarded unchanged. The
        # vocal-ratio code must be listed too; otherwise it falls through to the
        # missing-output check below and is misreported as gs_err_code_svc_process.
        if err_code in (gs_err_code_gender_classify, gs_err_code_vocal_ratio, gs_err_code_target_silence):
            self.update_result(err_code, 100, gender, worker_attr.target_url)
            return err_code, None, None
        if not os.path.exists(worker_attr.target_wav_path):
            self.update_result(gs_err_code_svc_process, 100, gender, worker_attr.target_url)
            return gs_err_code_svc_process, None, None
        self.update_result(err, 85, gender, worker_attr.target_url)

        # Stretch the volume to the requested loudness
        svc_offline_logger.info(f"{worker_attr.log_info_name()}, start volume_adjustment ...")
        volume_adjustment(worker_attr.target_wav_path, worker_attr.target_loudness, worker_attr.target_wav_ad_path)
        if not os.path.exists(worker_attr.target_wav_ad_path):
            self.update_result(gs_err_code_volume_adjust, 100, gender, worker_attr.target_url)
            return gs_err_code_volume_adjust, None, None
        self.update_result(err, 90, gender, worker_attr.target_url)

        # transcode
        svc_offline_logger.info(f"{worker_attr.log_info_name()}, start transcode ...")
        if not transcode(worker_attr.target_wav_ad_path, worker_attr.target_path):
            self.update_result(gs_err_code_transcode, 100, gender, worker_attr.target_url)
            return gs_err_code_transcode, None, None
        self.update_result(err, 95, gender, worker_attr.target_url)

        # upload
        svc_offline_logger.info(f"{worker_attr.log_info_name()}, start upload_file2cos ...")
        st = time.time()
        if not self.cos_helper.upload_by_url(worker_attr.target_path, worker_attr.target_url):
            self.update_result(gs_err_code_upload, 100, gender, worker_attr.target_url)
            return gs_err_code_upload, None, None
        self.update_result(gs_err_code_success, 100, gender, worker_attr.target_url)
        svc_offline_logger.info(
            f"{worker_attr.log_info_name()} upload {worker_attr.target_url} sp = {time.time() - st}")
        return gs_err_code_success, worker_attr.target_url, gender

    def start_signal(self):
        """
        Send a Feishu message announcing that this host's offline server has started.
        :return:
        """
        host_name = socket.gethostname()
        host_ip = socket.gethostbyname(host_name)
        msg = f"{host_name}({host_ip}) offline server start"
        feishu_send(gs_feishu_conf["url"], msg, gs_feishu_conf["users"])

    def process(self):
        """Main loop: poll the input queue forever and handle one job at a time."""
        self.start_signal()
        while True:
            data = self.redis_helper.rpop(self.server_conf["producer"])
            if data is None:
                time.sleep(1)
                continue
            try:
                data = json.loads(data)
            except ValueError:
                # A malformed message must not take the whole worker down.
                svc_offline_logger.error(f"input data error={data}")
                continue
            if not isinstance(data, dict) or not check_input(data):
                svc_offline_logger.error(f"input data error={data}")
                continue

            worker_attr = GSWorkerAttr(data)
            self.distinct_key = self.server_conf["ai_meisheng_key_prefix"] + worker_attr.distinct_id
            # A missing result key is treated as an abandoned (timed-out) request.
            if not self.exists():
                svc_offline_logger.warning(f"input {data}, timeout abandon ....")
                worker_attr.rm_cache()
                continue

            st = time.time()
            try:
                errcode, target_path, gender = self.process_one(worker_attr)
                self.update_result(errcode, 100, gender, target_path)
                svc_offline_logger.info(
                    f"{worker_attr.log_info_name()} finish errcode={errcode} sp = {time.time() - st}")
            finally:
                # Drop the job's scratch directory even if processing raised.
                worker_attr.rm_cache()
if __name__ == '__main__':
    # Start the worker loop with redis result updates enabled.
    offline_server = OfflineServer(gs_redis_conf, gs_server_redis_conf, update_redis=True)
    offline_server.process()
diff --git a/AIMeiSheng/docker_demo/svc_online.py b/AIMeiSheng/docker_demo/svc_online.py
index 1606703..421b5f6 100644
--- a/AIMeiSheng/docker_demo/svc_online.py
+++ b/AIMeiSheng/docker_demo/svc_online.py
@@ -1,197 +1,207 @@
# -*- coding: UTF-8 -*-
"""
SVC的核心处理逻辑
"""
import os
import time
import socket
import shutil
import hashlib
from AIMeiSheng.meisheng_svc_final import load_model, process_svc_online
from AIMeiSheng.cos_similar_ui_zoom import cos_similar
from AIMeiSheng.meisheng_env_preparex import meisheng_env_prepare
from AIMeiSheng.voice_classification.online.voice_class_online_fang import VoiceClass, download_volume_balanced
from AIMeiSheng.docker_demo.common import *
import logging
hostname = socket.gethostname()
log_file_name = f"{os.path.dirname(os.path.abspath(__file__))}/av_meisheng_{hostname}.log"
# Logger shared by the SVC pipeline (the offline server imports it as well).
svc_offline_logger = logging.getLogger("svc_offline")
file_handler = logging.FileHandler(log_file_name)
file_handler.setLevel(logging.INFO)
# 24-hour %H: a 12-hour %I without an AM/PM marker makes morning and evening
# timestamps indistinguishable.
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(formatter)
# NOTE: the file handler is only attached when the two lines below are re-enabled.
# if gs_prod:
#     svc_offline_logger.addHandler(file_handler)
# Start every process with an empty scratch directory; make sure model/cache directories exist.
if os.path.exists(gs_tmp_dir):
    shutil.rmtree(gs_tmp_dir)
os.makedirs(gs_model_dir, exist_ok=True)
os.makedirs(gs_resource_cache_dir, exist_ok=True)
# Download locations of the gender-classification models and the loudness tool.
gs_gender_models_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/hub/voice_classification/models.zip"
gs_volume_bin_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/dataset/AIMeiSheng/ebur128_tool/v1/ebur128_tool"
class GSWorkerAttr:
    """
    Per-job parameters and local file paths derived from one input message.

    Every intermediate and output file lives in gs_tmp_dir/<distinct_id>, where
    distinct_id is the md5 hex digest of the vocal URL. The directory is
    recreated empty on construction and deleted by rm_cache().
    """

    def __init__(self, input_data):
        """
        :param input_data: job message with record_song_url, target_url,
                           start/end (milliseconds), vocal_loudness and the
                           female/male reference recording URLs
        """
        vocal_url = input_data["record_song_url"]
        target_url = input_data["target_url"]
        start = input_data["start"]  # milliseconds
        end = input_data["end"]  # milliseconds
        vocal_loudness = input_data["vocal_loudness"]
        female_recording_url = input_data["female_recording_url"]
        male_recording_url = input_data["male_recording_url"]

        self.distinct_id = hashlib.md5(vocal_url.encode()).hexdigest()
        self.tmp_dir = os.path.join(gs_tmp_dir, self.distinct_id)
        if os.path.exists(self.tmp_dir):
            shutil.rmtree(self.tmp_dir)
        os.makedirs(self.tmp_dir)

        self.vocal_url = vocal_url
        self.target_url = target_url
        self.vocal_path = self._tmp_path("_in", vocal_url)
        self.target_wav_path = os.path.join(self.tmp_dir, self.distinct_id + "_out.wav")
        self.target_wav_ad_path = os.path.join(self.tmp_dir, self.distinct_id + "_out_ad.wav")
        self.target_path = os.path.join(self.tmp_dir, self.distinct_id + "_out.m4a")

        self.female_svc_source_url = female_recording_url
        self.male_svc_source_url = male_recording_url
        self.female_svc_source_path = self._tmp_path("_female", female_recording_url)
        self.male_svc_source_path = self._tmp_path("_male", male_recording_url)

        self.st_tm = start
        self.ed_tm = end
        self.target_loudness = vocal_loudness

    @staticmethod
    def _url_ext(url):
        """
        Return the file extension of ``url``'s path component, without the dot.

        The query string and fragment are ignored. "" is returned when the path
        has no plain alphanumeric extension. The old ``url.split(".")[-1]`` had
        two problems:
        - it pulled signed-URL query strings (``m4a?sign=..&t=..``) into local
          file names;
        - for extension-less paths it returned host fragments containing "/".
        """
        from urllib.parse import urlsplit
        ext = os.path.splitext(urlsplit(url).path)[1][1:]
        # Only plain alphanumeric extensions keep the local path shell-safe.
        return ext if ext.isalnum() else ""

    def _tmp_path(self, suffix, url):
        """Path in tmp_dir named <distinct_id><suffix>, plus ``url``'s extension when it has one."""
        ext = self._url_ext(url)
        name = self.distinct_id + suffix + (f".{ext}" if ext else "")
        return os.path.join(self.tmp_dir, name)

    def log_info_name(self):
        """Short job identifier used as a log prefix."""
        return f"d_id={self.distinct_id}, vocal_url={self.vocal_url}"

    def rm_cache(self):
        """Delete the job's scratch directory if it still exists."""
        if os.path.exists(self.tmp_dir):
            shutil.rmtree(self.tmp_dir)
def init_gender_model():
    """
    Ensure the voice/gender classification models exist and build a VoiceClass.

    The models live under gs_model_dir/voice_classification. When that
    directory is missing, the archive is downloaded and unzipped first.

    :return: VoiceClass instance
    :raises RuntimeError: when the archive cannot be downloaded or unpacked
    """
    dst_model_dir = os.path.join(gs_model_dir, "voice_classification")
    if not os.path.exists(dst_model_dir):
        dst_zip_path = os.path.join(gs_model_dir, "models.zip")
        if not download2disk(gs_gender_models_url, dst_zip_path):
            svc_offline_logger.fatal(f"download gender_model err={gs_gender_models_url}")
            # Fail fast instead of going on to unzip and load files that do not exist.
            raise RuntimeError(f"download gender_model err={gs_gender_models_url}")
        # -o overwrites leftovers of an interrupted earlier run instead of prompting on stdin.
        cmd = f"cd {gs_model_dir}; unzip -o {dst_zip_path}; mv models voice_classification; rm -f {dst_zip_path}"
        os.system(cmd)
        if not os.path.exists(dst_model_dir):
            svc_offline_logger.fatal(f"unzip {dst_zip_path} err")
            raise RuntimeError(f"unzip {dst_zip_path} err")
    music_voice_pure_model = os.path.join(dst_model_dir, "voice_005_rec_v5.pth")
    music_voice_no_pure_model = os.path.join(dst_model_dir, "voice_10_v5.pth")
    gender_pure_model = os.path.join(dst_model_dir, "gender_8k_ratev5_v6_adam.pth")
    gender_no_pure_model = os.path.join(dst_model_dir, "gender_8k_v6_adam.pth")
    vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model)
    return vc
def init_svc_model():
    """Prepare the SVC runtime under gs_model_dir and load its models.

    :return: (embed_model, hubert_model, cs_sim): the two models returned by
             load_model() plus a cos_similar instance
    """
    meisheng_env_prepare(logging, gs_model_dir)
    loaded = load_model()
    embed_model, hubert_model = loaded
    similarity = cos_similar()
    return embed_model, hubert_model, similarity
def download_volume_adjustment():
    """
    Fetch the ebur128 loudness tool into gs_model_dir if it is not there yet,
    then make it executable. A failed download is only logged.
    """
    tool_path = os.path.join(gs_model_dir, "ebur128_tool")
    if os.path.exists(tool_path):
        return
    if not download2disk(gs_volume_bin_url, tool_path):
        svc_offline_logger.fatal(f"download volume_bin err={gs_volume_bin_url}")
    os.system(f"chmod +x {tool_path}")
def volume_adjustment(wav_path, target_loudness, out_path):
    """
    Rescale ``wav_path`` to ``target_loudness`` with the ebur128 tool, writing ``out_path``.

    Nothing is returned; the caller checks whether ``out_path`` exists.
    ``target_loudness`` comes straight from the job message, so the tool is run
    with an argv list (no shell) rather than an interpolated shell string.

    :param wav_path: input wav file
    :param target_loudness: loudness value passed to the tool
    :param out_path: output wav file
    """
    import subprocess
    volume_bin_path = os.path.join(gs_model_dir, "ebur128_tool")
    try:
        subprocess.run([volume_bin_path, str(wav_path), str(target_loudness), str(out_path)], check=False)
    except OSError as err:
        # Missing or non-executable tool: like the old shell call, do not raise;
        # the caller detects the absent output file.
        svc_offline_logger.error(f"volume_adjustment err={err}")
class SVCOnline:
    """Gender classification plus singing-voice conversion for a single job."""

    # Minimum gender_model.vocal_ratio required before conversion; recordings
    # below it fail with gs_err_code_vocal_ratio.
    FEMALE_MIN_VOCAL_RATIO = 0.5
    MALE_MIN_VOCAL_RATIO = 0.6

    def __init__(self):
        """Load or download every model and tool that process() needs."""
        st = time.time()
        self.gender_model = init_gender_model()
        self.embed_model, self.hubert_model, self.cs_sim = init_svc_model()
        download_volume_adjustment()
        download_volume_balanced()
        svc_offline_logger.info(f"svc init finished, sp = {time.time() - st}")

    def gender_process(self, worker_attr):
        """
        Classify the singer's gender from worker_attr.vocal_path.

        :return: (gender, err). gender is 'female' or 'male'. err is one of:
                 - gs_err_code_success;
                 - gs_err_code_gender_classify when the classifier returned an
                   error code instead of a prediction;
                 - gs_err_code_vocal_ratio when too little of the recording
                   contains voice.
        """
        st = time.time()
        gender, female_rate, is_pure = self.gender_model.process(worker_attr.vocal_path)
        svc_offline_logger.info(
            f"{worker_attr.vocal_url}, gender={gender}, female_rate={female_rate}, is_pure={is_pure}, "
            f"gender_process sp = {time.time() - st}")
        if gender == 0:
            gender = 'female'
        elif gender == 1:
            gender = 'male'
        elif female_rate is None:
            # The classifier failed (it returns an error code with no rate).
            gender = 'male'
            return gender, gs_err_code_gender_classify
        elif female_rate > 0.5:
            gender = 'female'
        else:
            gender = 'male'

        min_ratio = self.FEMALE_MIN_VOCAL_RATIO if gender == 'female' else self.MALE_MIN_VOCAL_RATIO
        vocal_ratio = self.gender_model.vocal_ratio
        if vocal_ratio < min_ratio:
            svc_offline_logger.warning(
                f"@@@ vocal_ratio: {vocal_ratio}, gender : {gender}, gs_err_code_vocal_ratio : {gs_err_code_vocal_ratio}")
            return gender, gs_err_code_vocal_ratio

        svc_offline_logger.info(f"{worker_attr.vocal_url}, modified gender={gender}")
        # err = gs_err_code_success
        # if female_rate == -1:
        # err = gs_err_code_target_silence
        return gender, gs_err_code_success

    def process(self, worker_attr):
        """
        Classify gender, then convert worker_attr.vocal_path using the female or
        male reference recording, writing worker_attr.target_wav_path.

        :return: (gender, err_code). err_code comes from gender_process() when
                 that fails, otherwise from process_svc_online().
        """
        gender, err = self.gender_process(worker_attr)
        if err != gs_err_code_success:
            return gender, err
        song_path = worker_attr.female_svc_source_path
        if gender == "male":
            song_path = worker_attr.male_svc_source_path
        params = {'gender': gender, 'tst': worker_attr.st_tm, "tnd": worker_attr.ed_tm, 'delay': 0, 'song_path': None}
        st = time.time()
        err_code = process_svc_online(song_path, worker_attr.vocal_path, worker_attr.target_wav_path, self.embed_model,
                                      self.hubert_model, self.cs_sim, params)
        svc_offline_logger.info(f"{worker_attr.vocal_url}, err_code={err_code} process svc sp = {time.time() - st}")
        return gender, err_code
diff --git a/AIMeiSheng/voice_classification/online/voice_class_online_fang.py b/AIMeiSheng/voice_classification/online/voice_class_online_fang.py
index d792db0..a74173d 100644
--- a/AIMeiSheng/voice_classification/online/voice_class_online_fang.py
+++ b/AIMeiSheng/voice_classification/online/voice_class_online_fang.py
@@ -1,444 +1,446 @@
"""
男女声分类在线工具
1 转码为16bit单声道
2 均衡化
3 模型分类
"""
import os
import sys
import librosa
import shutil
import logging
import time
import torch.nn.functional as F
import numpy as np
from model import *
# from common import bind_kernel
# Configures the root logger at INFO level on import (affects every logger in the process).
logging.basicConfig(level=logging.INFO)
# NOTE(review): presumably caps a native (torch/MKL-DNN) LRU cache to limit CPU memory growth — confirm.
os.environ["LRU_CACHE_CAPACITY"] = "1"
# torch.set_num_threads(1)
# bind_kernel(1)
# Temporary module-level accumulators (seconds). transcode / volume_balanced /
# get_one_mfcc / batch_predict add to them, and the test helpers print them.
"""
临时用一下,全局使用的变量
"""
transcode_time = 0
vb_time = 0
mfcc_time = 0
predict_time = 0
# Error codes returned in place of a gender by VoiceClass.process_one_logic.
"""
错误码
"""
ERR_CODE_SUCCESS = 0 # processing succeeded
ERR_CODE_NO_FILE = -1 # file does not exist
ERR_CODE_TRANSCODE = -2 # transcoding failed
ERR_CODE_VOLUME_BALANCED = -3 # volume balancing failed
ERR_CODE_FEATURE_TOO_SHORT = -4 # feature sequence too short
# Constants
"""
常量
"""
# Frames per classification chunk (VoiceClass.predict) and MFCC coefficients per frame.
FRAME_LEN = 128
MFCC_LEN = 80
gs_bin_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/hub/voice_classification/bin/bin.zip"
# Loudness-balancing binary, unpacked from gs_bin_url by download_volume_balanced().
EBUR128_BIN = "/tmp/voice_class_bin/standard_audio_no_cut"
# EBUR128_BIN = "/Users/yangjianli/linux/opt/soft/bin/standard_audio_no_cut"
# Labels produced by the gender classification.
GENDER_FEMALE = 0
GENDER_MALE = 1
GENDER_OTHER = 2
# Generic helpers
"""
通用函数
"""
def exec_cmd(cmd):
    """Run ``cmd`` through the shell; return True iff it exited with status 0."""
    exit_status = os.system(cmd)
    return exit_status == 0
# Business-specific helpers
"""
业务需要的函数
"""
def get_one_mfcc(file_url):
    """Load ``file_url`` resampled to 16 kHz and return its MFCC matrix, one row per frame.

    An empty list is returned for clips shorter than 512 samples. The elapsed
    time is added to the module-level ``mfcc_time`` counter.
    """
    global mfcc_time
    begin = time.time()
    samples, rate = librosa.load(file_url, sr=16000)
    if len(samples) < 512:
        return []
    coeffs = librosa.feature.mfcc(y=samples, sr=rate, n_fft=512, hop_length=256, n_mfcc=MFCC_LEN)
    frames = coeffs.transpose()
    print("get_one_mfcc:spend_time={}".format(time.time() - begin))
    mfcc_time += time.time() - begin
    return frames
def download_volume_balanced():
    """
    Make sure the loudness-balancing binary EBUR128_BIN exists.

    On first use it is downloaded from gs_bin_url and unpacked into /tmp. The
    process exits with status -1 if the archive or the binary is still missing.
    """
    import urllib.request
    if os.path.exists(EBUR128_BIN):
        return
    dst_path = "/tmp/bin.zip"
    urllib.request.urlretrieve(gs_bin_url, dst_path)
    if not os.path.exists(dst_path):
        print(f"download dst_path={gs_bin_url} err!")
        sys.exit(-1)
    dirname = os.path.dirname(dst_path)
    # -o overwrites leftovers of an interrupted earlier run instead of prompting on stdin.
    cmd = f"cd {dirname}; unzip -o bin.zip; rm -f bin.zip; mv bin voice_class_bin"
    os.system(cmd)
    if not os.path.exists(EBUR128_BIN):
        print(f"exec {cmd} err!")
        # sys.exit, not the site-module `exit`, which is meant for the interactive
        # shell and is unavailable under `python -S`.
        sys.exit(-1)
def volume_balanced(src, dst):
    """
    Loudness-balance ``src`` into ``dst`` with EBUR128_BIN (downloaded on first use).

    Adds the elapsed time to the module-level ``vb_time`` counter.

    :return: True when ``dst`` exists afterwards
    """
    import shlex
    global vb_time
    st = time.time()
    download_volume_balanced()
    # The command runs through the shell and the paths derive from the caller's
    # input file name, so quote them.
    cmd = "{} {} {}".format(EBUR128_BIN, shlex.quote(src), shlex.quote(dst))
    logging.info(cmd)
    exec_cmd(cmd)
    if not os.path.exists(dst):
        logging.error("volume_balanced:cmd={}".format(cmd))
    print("volume_balanced:spend_time={}".format(time.time() - st))
    vb_time += time.time() - st
    return os.path.exists(dst)
def transcode(src, dst):
    """
    Convert ``src`` to a 16 kHz mono file ``dst`` with ffmpeg.

    Adds the elapsed time to the module-level ``transcode_time`` counter.

    :return: True when ``dst`` exists afterwards
    """
    import shlex
    global transcode_time
    st = time.time()
    # `src` is the caller-supplied input file; quote both paths because the
    # command runs through the shell.
    cmd = "ffmpeg -loglevel quiet -i {} -ar 16000 -ac 1 {}".format(shlex.quote(src), shlex.quote(dst))
    logging.info(cmd)
    exec_cmd(cmd)
    if not os.path.exists(dst):
        logging.error("transcode:cmd={}".format(cmd))
    print("transcode:spend_time={}".format(time.time() - st))
    transcode_time += time.time() - st
    return os.path.exists(dst)
class VoiceClass:
    """
    Voice / gender classifier driven by four models.

    Audio is transcoded, loudness-balanced and turned into MFCC chunks. The
    "pure" model pair is tried first; when it is inconclusive, the "no pure"
    pair decides.
    """

    def __init__(self, music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model):
        """
        Load the four checkpoints on CPU.

        :param music_voice_pure_model: checkpoint separating pure vocals from everything else
        :param music_voice_no_pure_model: checkpoint separating "contains voice" from everything else
        :param gender_pure_model: checkpoint classifying gender on pure vocals
        :param gender_no_pure_model: checkpoint classifying gender on audio containing voice
        """
        st = time.time()
        self.device = "cpu"
        self.batch_size = 256
        self.music_voice_pure_model = load_model(MusicVoiceV5Model, music_voice_pure_model, self.device)
        self.music_voice_no_pure_model = load_model(MusicVoiceV5Model, music_voice_no_pure_model, self.device)
        self.gender_pure_model = load_model(MobileNetV2Gender, gender_pure_model, self.device)
        self.gender_no_pure_model = load_model(MobileNetV2Gender, gender_no_pure_model, self.device)
        logging.info("load model ok ! spend_time={}".format(time.time() - st))
        # Share of chunks the pure-vocal model judged to be voice in the latest
        # predict_pure() call; callers read it after process().
        self.vocal_ratio = 1

    def batch_predict(self, model, features):
        """
        Run ``model`` over ``features`` in batches of self.batch_size.

        :return: numpy array with one row of softmax probabilities per input row
        """
        st = time.time()
        scores = []
        with torch.no_grad():
            for i in range(0, len(features), self.batch_size):
                cur_data = features[i:i + self.batch_size].to(self.device)
                predicts = model(cur_data)
                predicts_score = F.softmax(predicts, dim=1)
                scores.extend(predicts_score.cpu().numpy())
        ret = np.array(scores)
        global predict_time
        predict_time += time.time() - st
        return ret

    def predict_pure(self, filename, features):
        """
        Classify gender with the pure-vocal model pair.

        Chunks whose non-voice probability (column 0) exceeds 0.5 are dropped.
        The share of kept chunks is stored in self.vocal_ratio.

        :param filename: used in log messages only
        :param features: tensor of MFCC chunks, one chunk per row
        :return: (gender, female_rate), or (GENDER_OTHER, -1) when too few chunks are voice
        """
        if len(features) == 0:
            # No complete chunk to classify (this also avoids dividing by zero below).
            self.vocal_ratio = 0.0
            logging.warning("filename={}|predict_pure|other|len=0|rate=0".format(filename))
            return GENDER_OTHER, -1
        scores = self.batch_predict(self.music_voice_pure_model, features)
        new_features = []
        for idx, score in enumerate(scores):
            if score[0] > 0.5:  # not voice
                continue
            new_features.append(features[idx].numpy())
        # Too few voice chunks to classify (thresholds are tunable).
        new_feature_len = len(new_features)
        new_feature_rate = len(new_features) / len(features)
        self.vocal_ratio = new_feature_rate
        # The vocal-ratio threshold is deliberately low.
        if new_feature_len < 4 or new_feature_rate < 0.05:
            logging.warning(
                "filename={}|predict_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate)
            )
            return GENDER_OTHER, -1
        new_features = torch.from_numpy(np.array(new_features))
        scores = self.batch_predict(self.gender_pure_model, new_features)
        f_avg = sum(scores[:, 0]) / len(scores)
        m_avg = sum(scores[:, 1]) / len(scores)
        female_rate = f_avg / (f_avg + m_avg)
        if female_rate > 0.65:
            return GENDER_FEMALE, female_rate
        if female_rate < 0.12:
            return GENDER_MALE, female_rate
        logging.warning(
            "filename={}|predict_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate)
        )
        return GENDER_OTHER, female_rate

    def predict_no_pure(self, filename, features):
        """
        Classify gender with the "contains voice" model pair (fallback for predict_pure).

        :param filename: used in log messages only
        :param features: tensor of MFCC chunks, one chunk per row
        :return: (gender, female_rate), or (GENDER_OTHER, -1) when too few chunks are voice
        """
        if len(features) == 0:
            # No complete chunk to classify (this also avoids dividing by zero below).
            logging.warning("filename={}|predict_no_pure|other|len=0|rate=0".format(filename))
            return GENDER_OTHER, -1
        scores = self.batch_predict(self.music_voice_no_pure_model, features)
        new_features = []
        for idx, score in enumerate(scores):
            if score[0] > 0.5:  # not voice
                continue
            new_features.append(features[idx].numpy())
        # Too few voice chunks to classify (thresholds are tunable).
        new_feature_len = len(new_features)
        new_feature_rate = len(new_features) / len(features)
        # The vocal-ratio threshold is deliberately low.
        if new_feature_len < 4 or new_feature_rate < 0.05:
            logging.warning(
                "filename={}|predict_no_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate)
            )
            return GENDER_OTHER, -1
        new_features = torch.from_numpy(np.array(new_features))
        scores = self.batch_predict(self.gender_no_pure_model, new_features)
        f_avg = sum(scores[:, 0]) / len(scores)
        m_avg = sum(scores[:, 1]) / len(scores)
        female_rate = f_avg / (f_avg + m_avg)
        if female_rate > 0.75:
            return GENDER_FEMALE, female_rate
        if female_rate < 0.1:
            return GENDER_MALE, female_rate
        logging.warning(
            "filename={}|predict_no_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate)
        )
        return GENDER_OTHER, female_rate

    def predict(self, filename, features):
        """
        Classify ``features`` (rows are frames).

        The rows are split into consecutive FRAME_LEN-row chunks; a trailing
        partial chunk is ignored.

        :return: (gender, female_rate, is_pure). is_pure is False when the pure
                 pass was inconclusive and the no-pure models decided.
        """
        st = time.time()
        # `len(features) + 1` keeps the last complete chunk. The previous bound
        # dropped it, and produced no chunk at all for exactly FRAME_LEN frames.
        new_features = [features[i - FRAME_LEN: i] for i in range(FRAME_LEN, len(features) + 1, FRAME_LEN)]
        new_features = torch.from_numpy(np.array(new_features))
        gender, rate = self.predict_pure(filename, new_features)
        if gender == GENDER_OTHER:
            logging.info("start no pure process...")
            gender, rate = self.predict_no_pure(filename, new_features)
            return gender, rate, False
        print("predict|spend_time={}".format(time.time() - st))
        return gender, rate, True

    def process_one_logic(self, filename, file_path, cache_dir):
        """
        Transcode, loudness-balance and featurise ``file_path`` inside ``cache_dir``, then classify it.

        :return: predict()'s (gender, female_rate, is_pure), or (ERR_CODE_*, None, None) on failure
        """
        tmp_wav = os.path.join(cache_dir, "tmp.wav")
        tmp_vb_wav = os.path.join(cache_dir, "tmp_vb.wav")
        if not transcode(file_path, tmp_wav):
            return ERR_CODE_TRANSCODE, None, None
        if not volume_balanced(tmp_wav, tmp_vb_wav):
            return ERR_CODE_VOLUME_BALANCED, None, None
        features = get_one_mfcc(tmp_vb_wav)
        if len(features) < FRAME_LEN:
            logging.error("feature too short|file_path={}".format(file_path))
            return ERR_CODE_FEATURE_TOO_SHORT, None, None
        return self.predict(filename, features)

    def process_one(self, file_path):
        """Classify ``file_path`` using the scratch directory "<file_path without extension>_cache"."""
        filename = os.path.splitext(file_path)[0]
        print("filename:", filename)
        # `filename` already contains the directory part; joining it onto the
        # dirname again nested the cache dir for relative paths.
        cache_dir = filename + "_cache"
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
        os.makedirs(cache_dir)
        try:
            return self.process_one_logic(filename, file_path, cache_dir)
        finally:
            # Remove the scratch directory even if processing raises.
            shutil.rmtree(cache_dir)

    def process(self, file_path):
        """
        Classify the audio file at ``file_path``.

        :return: (gender, female_rate, is_pure), or (ERR_CODE_*, None, None) when
                 transcoding, volume balancing or feature extraction fails
        """
        # Reset to the constructor default so a failed run cannot expose the previous file's ratio.
        self.vocal_ratio = 1
        gender, female_rate, is_pure = self.process_one(file_path)
        logging.info("{}|gender={}|female_rate={}".format(file_path, gender, female_rate))
        return gender, female_rate, is_pure

    def process_by_feature(self, feature_file):
        """
        Classify a precomputed feature matrix stored as a numpy ``.npy`` file.

        :param feature_file: path loadable with np.load
        :return: (gender, female_rate)
        """
        filename = os.path.splitext(feature_file)[0]
        features = np.load(feature_file)
        # predict() returns three values; unpacking into two raised ValueError.
        gender, female_rate, _ = self.predict(filename, features)
        return gender, female_rate
def test_all_feature():
    """
    Evaluate the classifier on precomputed feature files.

    Runs process_by_feature() on every ``*feature.npy`` under the female/,
    male/ and other/ folders of a hard-coded dataset directory. It prints:
    - each misclassification and the accumulated per-stage timings;
    - the confusion counts;
    - female/male precision ("acc") and recall.
    """
    import glob
    global transcode_time, vb_time, mfcc_time, predict_time
    feature_root = "/data/datasets/music_voice_dataset_full/feature_online_data_v3"
    # (true label, files, message printed when the prediction disagrees)
    labelled_sets = [
        (0, glob.glob(os.path.join(feature_root, "female/*feature.npy")), "err:female->{}|{}|{}"),
        (1, glob.glob(os.path.join(feature_root, "male/*feature.npy")), "err:male->{}|{}|{}"),
        (2, glob.glob(os.path.join(feature_root, "other/*feature.npy")), "err:other->{}|{}|{}"),
    ]
    model_dir = "/data/jianli.yang/voice_classification/online/models"
    classifier = VoiceClass(
        os.path.join(model_dir, "voice_005_rec_v5.pth"),
        os.path.join(model_dir, "voice_10_v5.pth"),
        os.path.join(model_dir, "gender_8k_ratev5_v6_adam.pth"),
        os.path.join(model_dir, "gender_8k_v6_adam.pth"),
    )
    overall_start = time.time()
    # confusion[true_label][predicted_label] -> count
    confusion = {truth: {0: 0, 1: 0, 2: 0} for truth in (0, 1, 2)}
    for truth, files, err_fmt in labelled_sets:
        for path in files:
            item_start = time.time()
            print("------------------------------>>>>>")
            predicted, female_score = classifier.process_by_feature(path)
            confusion[truth][predicted] += 1
            if predicted != truth:
                print(err_fmt.format(predicted, path, female_score))
            print("process|spend_tm=={}".format(time.time() - item_start))
    print("spend_time:tot={}|transcode={}|vb={}|gen_feature={}|predict={}".format(
        time.time() - overall_start, transcode_time, vb_time, mfcc_time, predict_time))
    f_f, f_m, f_o = (confusion[0][p] for p in (0, 1, 2))
    m_f, m_m, m_o = (confusion[1][p] for p in (0, 1, 2))
    o_f, o_m, o_o = (confusion[2][p] for p in (0, 1, 2))
    print("ff:{},fm:{},fo:{}".format(f_f, f_m, f_o))
    print("mm:{},mf:{},mo:{}".format(m_m, m_f, m_o))
    print("om:{},of:{},oo:{}".format(o_m, o_f, o_o))
    # Female precision and recall
    f_acc = f_f / (f_f + m_f + o_f)
    f_recall = f_f / (f_f + f_m + f_o)
    # Male precision and recall
    m_acc = m_m / (m_m + f_m + o_m)
    m_recall = m_m / (m_m + m_f + m_o)
    print("female: acc={}|recall={}".format(f_acc, f_recall))
    print("male: acc={}|recall={}".format(m_acc, m_recall))
def test_all():
    """
    Evaluate the classifier end to end on raw audio.

    Runs process() on every mp4 under the female/, male/ and other/ folders of
    a hard-coded dataset directory. It prints:
    - each misclassification and the accumulated per-stage timings;
    - the confusion counts;
    - female/male precision ("acc") and recall.
    """
    import glob
    global transcode_time, vb_time, mfcc_time, predict_time
    base_dir = "/data/datasets/music_voice_dataset_full/online_data_v3_top200"
    # (true label, files, message printed when the prediction disagrees)
    labelled_sets = [
        (0, glob.glob(os.path.join(base_dir, "female/*mp4")), "err:female->{}|{}|{}"),
        (1, glob.glob(os.path.join(base_dir, "male/*mp4")), "err:male->{}|{}|{}"),
        (2, glob.glob(os.path.join(base_dir, "other/*mp4")), "err:other->{}|{}|{}"),
    ]
    model_path = "/data/jianli.yang/voice_classification/online/models"
    vc = VoiceClass(
        os.path.join(model_path, "voice_005_rec_v5.pth"),
        os.path.join(model_path, "voice_10_v5.pth"),
        os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth"),
        os.path.join(model_path, "gender_8k_v6_adam.pth"),
    )
    tot_st = time.time()
    # ret_map[true_label][predicted_label] -> count
    ret_map = {truth: {0: 0, 1: 0, 2: 0} for truth in (0, 1, 2)}
    for truth, files, err_fmt in labelled_sets:
        for file in files:
            st = time.time()
            print("------------------------------>>>>>")
            # VoiceClass.process returns (gender, female_rate, is_pure); unpacking
            # it into two names raised ValueError on the first file.
            gender, female_score, _ = vc.process(file)
            ret_map[truth][gender] += 1
            if gender != truth:
                print(err_fmt.format(gender, file, female_score))
            print("process|spend_tm=={}".format(time.time() - st))
    print("spend_time:tot={}|transcode={}|vb={}|gen_feature={}|predict={}".format(
        time.time() - tot_st, transcode_time, vb_time, mfcc_time, predict_time))
    f_f, f_m, f_o = (ret_map[0][p] for p in (0, 1, 2))
    m_f, m_m, m_o = (ret_map[1][p] for p in (0, 1, 2))
    o_f, o_m, o_o = (ret_map[2][p] for p in (0, 1, 2))
    print("ff:{},fm:{},fo:{}".format(f_f, f_m, f_o))
    print("mm:{},mf:{},mo:{}".format(m_m, m_f, m_o))
    print("om:{},of:{},oo:{}".format(o_m, o_f, o_o))
    # Female precision and recall
    f_acc = f_f / (f_f + m_f + o_f)
    f_recall = f_f / (f_f + f_m + f_o)
    # Male precision and recall
    m_acc = m_m / (m_m + f_m + o_m)
    m_recall = m_m / (m_m + m_f + m_o)
    print("female: acc={}|recall={}".format(f_acc, f_recall))
    print("male: acc={}|recall={}".format(m_acc, m_recall))
if __name__ == "__main__":
    # test_all()
    # test_all_feature()
    # Usage: <script> <model_dir> <audio_file>
    model_path, voice_path = sys.argv[1], sys.argv[2]
    vc = VoiceClass(
        os.path.join(model_path, "voice_005_rec_v5.pth"),
        os.path.join(model_path, "voice_10_v5.pth"),
        os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth"),
        os.path.join(model_path, "gender_8k_v6_adam.pth"),
    )
    st = time.time()
    print("------------------------------>>>>>")
    gender, female_rate, is_pure = vc.process(voice_path)
    print("process|spend_tm=={}".format(time.time() - st))
    print("gender:{}, female_rate:{},is_pure:{}".format(gender, female_rate, is_pure))

File Metadata

Mime Type
text/x-diff
Expires
Sat, Nov 23, 18:14 (1 d, 13 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
1325561
Default Alt Text
(38 KB)

Event Timeline