diff --git a/AIMeiSheng/check_crash.py b/AIMeiSheng/check_crash.py new file mode 100644 index 0000000..e69de29 diff --git a/AIMeiSheng/docker_demo/common.py b/AIMeiSheng/docker_demo/common.py index 6d3d716..49b3628 100644 --- a/AIMeiSheng/docker_demo/common.py +++ b/AIMeiSheng/docker_demo/common.py @@ -1,128 +1,130 @@ import os import sys import time # import logging import urllib, urllib.request # 测试/正式环境 gs_prod = True # if len(sys.argv) > 1 and sys.argv[1] == "prod": # gs_prod = True # print(gs_prod) gs_tmp_dir = "/data/ai_meisheng_tmp" gs_model_dir = "/data/ai_meisheng_models" gs_resource_cache_dir = "/tmp/ai_meisheng_resource_cache" gs_embed_model_path = os.path.join(gs_model_dir, "RawNet3/models/weights/model.pt") gs_svc_model_path = os.path.join(gs_model_dir, "weights/xusong_v2_org_version_alldata_embed_spkenx200x_double_e14_s90706.pth") gs_hubert_model_path = os.path.join(gs_model_dir, "hubert.pt") gs_rmvpe_model_path = os.path.join(gs_model_dir, "rmvpe.pt") gs_embed_model_spk_path = os.path.join(gs_model_dir, "SpeakerEncoder/pretrained_model/best_model.pth.tar") gs_embed_config_spk_path = os.path.join(gs_model_dir, "SpeakerEncoder/pretrained_model/config.json") # errcode gs_err_code_success = 0 gs_err_code_download_vocal = 100 gs_err_code_download_svc_url = 101 gs_err_code_svc_process = 102 gs_err_code_transcode = 103 gs_err_code_volume_adjust = 104 gs_err_code_upload = 105 gs_err_code_params = 106 gs_err_code_pending = 107 gs_err_code_target_silence = 108 gs_err_code_too_many_connections = 429 gs_err_code_gender_classify = 430 +gs_err_code_vocal_ratio = 431 #人声占比 + gs_redis_conf = { "host": "av-credis.starmaker.co", "port": 6379, "pwd": "lKoWEhz%jxTO", } # gs_server_redis_conf = { # "producer": "dev_ai_meisheng_producer", # 输入的队列 # "ai_meisheng_key_prefix": "dev_ai_meisheng_key_", # 存储结果情况 # } gs_server_redis_conf = { "producer": "test_ai_meisheng_producer", # 输入的队列 "ai_meisheng_key_prefix": "test_ai_meisheng_key_", # 存储结果情况 } if gs_prod: gs_server_redis_conf = { "producer": "ai_meisheng_producer", # 输入的队列 "ai_meisheng_key_prefix": "ai_meisheng_key_", # 存储结果情况 } gs_feishu_conf = { "url": "http://sg-prod-songbook-webmp-1:8000/api/feishu/people", "users": [ "18810833785", # 杨建利 "17778007843", # 王健军 "18612496315", # 郭子豪 "18600542290" # 方兵晓 ] } def download2disk(url, dst_path): try: urllib.request.urlretrieve(url, dst_path) return os.path.exists(dst_path) except Exception as ex: print(f"download url={url} error", ex) return False def exec_cmd(cmd): # gs_logger.info(cmd) print(cmd) ret = os.system(cmd) if ret != 0: return False return True def exec_cmd_and_result(cmd): r = os.popen(cmd) text = r.read() r.close() return text def upload_file2cos(key, file_path, region='ap-singapore', bucket_name='av-audit-sync-sg-1256122840'): """ 将文件上传到cos :param key: 桶上的具体地址 :param file_path: 本地文件地址 :param region: 区域 :param bucket_name: 桶地址 :return: """ gs_coscmd = "coscmd" gs_coscmd_conf = "~/.cos.conf" cmd = "{} -c {} -r {} -b {} upload {} {}".format(gs_coscmd, gs_coscmd_conf, region, bucket_name, file_path, key) if exec_cmd(cmd): cmd = "{} -c {} -r {} -b {} info {}".format(gs_coscmd, gs_coscmd_conf, region, bucket_name, key) \ + "| grep Content-Length |awk \'{print $2}\'" res_str = exec_cmd_and_result(cmd) # logging.info("{},res={}".format(key, res_str)) size = float(res_str) if size > 0: return True return False return False def check_input(input_data): key_list = ["record_song_url", "target_url", "start", "end", "vocal_loudness", "female_recording_url", "male_recording_url"] for key in key_list: if key not in input_data.keys(): return False return True diff --git a/AIMeiSheng/docker_demo/offline_server.py b/AIMeiSheng/docker_demo/offline_server.py index f92592e..73f4c18 100644 --- a/AIMeiSheng/docker_demo/offline_server.py +++ b/AIMeiSheng/docker_demo/offline_server.py @@ -1,166 +1,170 @@ # -*- coding: UTF-8 -*- """ 离线处理: 使用redis进行交互,从redis中获取数据资源,在将结果写入到redis """ import os import sys import time import json import socket import hashlib from redis_helper import RedisHelper from cos_helper import CosHelper from common import * import logging from feishu_helper import feishu_send from svc_online import GSWorkerAttr, SVCOnline, volume_adjustment, svc_offline_logger sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.join(os.path.dirname(__file__), "../")) def download_data(worker_attr): if os.path.exists(worker_attr.vocal_path): os.unlink(worker_attr.vocal_path) st = time.time() if not download2disk(worker_attr.vocal_url, worker_attr.vocal_path): return gs_err_code_download_vocal svc_offline_logger.info(f"download vocal_url={worker_attr.vocal_url} sp = {time.time() - st}") # download svc_source_url if not os.path.exists(worker_attr.female_svc_source_path): st = time.time() if not download2disk(worker_attr.female_svc_source_url, worker_attr.female_svc_source_path): return gs_err_code_download_svc_url svc_offline_logger.info(f"download female_url={worker_attr.female_svc_source_url} sp = {time.time() - st}") # download svc_source_url if not os.path.exists(worker_attr.male_svc_source_path): st = time.time() if not download2disk(worker_attr.male_svc_source_url, worker_attr.male_svc_source_path): return gs_err_code_download_svc_url svc_offline_logger.info(f"download male_url={worker_attr.male_svc_source_url} sp = {time.time() - st}") return gs_err_code_success def transcode(wav_path, dst_path): st = time.time() cmd = f"ffmpeg -i {wav_path} -ar 44100 -ac 1 -b:a 64k -y {dst_path} -loglevel fatal" exec_cmd(cmd) svc_offline_logger.info(f"transcode cmd={cmd}, sp = {time.time() - st}") return os.path.exists(dst_path) class OfflineServer: def __init__(self, redis_conf, server_conf, update_redis=False): st = time.time() self.redis_helper = RedisHelper(redis_conf) self.cos_helper = CosHelper() self.svc_online = SVCOnline() self.server_conf = server_conf self.distinct_key = server_conf["ai_meisheng_key_prefix"] self.update_redis = update_redis svc_offline_logger.info(f"config={redis_conf}---server_conf={self.server_conf}") svc_offline_logger.info(f"offline init finish sp={time.time() - st}") def exists(self): return self.redis_helper.exists(self.distinct_key) def update_result(self, errcode, schedule, gender, target_song_url): msg = { "status": errcode, "schedule": schedule, "gender": gender, "target_song_url": target_song_url, } # 结果保存15min if self.update_redis: self.redis_helper.set(self.distinct_key, json.dumps(msg)) self.redis_helper.expire(self.distinct_key, 60 * 10) def process_one(self, worker_attr): self.distinct_key = self.server_conf["ai_meisheng_key_prefix"] + worker_attr.distinct_id svc_offline_logger.info(f"{worker_attr.log_info_name()}, start download ...") err = download_data(worker_attr) if err != gs_err_code_success: self.update_result(err, 100, "unknown", worker_attr.target_url) return err, None, None self.update_result(err, 35, "unknown", worker_attr.target_url) svc_offline_logger.info(f"{worker_attr.log_info_name()}, start process ...") gender, err_code = self.svc_online.process(worker_attr) + if err_code == gs_err_code_gender_classify: + self.update_result(gs_err_code_gender_classify, 100, gender, worker_attr.target_url) + return gs_err_code_gender_classify, None, None if err_code == gs_err_code_target_silence: # unvoice err + self.update_result(gs_err_code_target_silence, 100, gender, worker_attr.target_url) return gs_err_code_target_silence, None, None if not os.path.exists(worker_attr.target_wav_path): self.update_result(gs_err_code_svc_process, 100, gender, worker_attr.target_url) return gs_err_code_svc_process, None, None self.update_result(err, 85, gender, worker_attr.target_url) # 音量拉伸到指定响度 svc_offline_logger.info(f"{worker_attr.log_info_name()}, start volume_adjustment ...") volume_adjustment(worker_attr.target_wav_path, worker_attr.target_loudness, worker_attr.target_wav_ad_path) if not os.path.exists(worker_attr.target_wav_ad_path): self.update_result(gs_err_code_volume_adjust, 100, gender, worker_attr.target_url) return gs_err_code_volume_adjust, None, None self.update_result(err, 90, gender, worker_attr.target_url) # transcode svc_offline_logger.info(f"{worker_attr.log_info_name()}, start transcode ...") if not transcode(worker_attr.target_wav_ad_path, worker_attr.target_path): self.update_result(gs_err_code_transcode, 100, gender, worker_attr.target_url) return gs_err_code_transcode, None, None self.update_result(err, 95, gender, worker_attr.target_url) # upload svc_offline_logger.info(f"{worker_attr.log_info_name()}, start upload_file2cos ...") st = time.time() if not self.cos_helper.upload_by_url(worker_attr.target_path, worker_attr.target_url): self.update_result(gs_err_code_upload, 100, gender, worker_attr.target_url) return gs_err_code_upload, None, None self.update_result(gs_err_code_success, 100, gender, worker_attr.target_url) svc_offline_logger.info( f"{worker_attr.log_info_name()} upload {worker_attr.target_url} sp = {time.time() - st}") return gs_err_code_success, worker_attr.target_url, gender def start_signal(self): """ 程序启动,发送消息到飞书 :return: """ host_name = socket.gethostname() host_ip = socket.gethostbyname(host_name) msg = f"{host_name}({host_ip}) offline server start" feishu_send(gs_feishu_conf["url"], msg, gs_feishu_conf["users"]) def process(self): self.start_signal() while True: data = self.redis_helper.rpop(self.server_conf["producer"]) if data is None: time.sleep(1) continue data = json.loads(data) if not check_input(data): svc_offline_logger.error(f"input data error={data}") continue worker_attr = GSWorkerAttr(data) self.distinct_key = self.server_conf["ai_meisheng_key_prefix"] + worker_attr.distinct_id if not self.exists(): svc_offline_logger.warning(f"input {data}, timeout abandon ....") worker_attr.rm_cache() continue st = time.time() errcode, target_path, gender = self.process_one(worker_attr) self.update_result(errcode, 100, gender, target_path) svc_offline_logger.info(f"{worker_attr.log_info_name()} finish errcode={errcode} sp = {time.time() - st}") worker_attr.rm_cache() if __name__ == '__main__': offline_server = OfflineServer(gs_redis_conf, gs_server_redis_conf, True) offline_server.process() diff --git a/AIMeiSheng/docker_demo/svc_online.py b/AIMeiSheng/docker_demo/svc_online.py index 1606703..421b5f6 100644 --- a/AIMeiSheng/docker_demo/svc_online.py +++ b/AIMeiSheng/docker_demo/svc_online.py @@ -1,197 +1,207 @@ # -*- coding: UTF-8 -*- """ SVC的核心处理逻辑 """ import os import time import socket import shutil import hashlib from AIMeiSheng.meisheng_svc_final import load_model, process_svc_online from AIMeiSheng.cos_similar_ui_zoom import cos_similar from AIMeiSheng.meisheng_env_preparex import meisheng_env_prepare from AIMeiSheng.voice_classification.online.voice_class_online_fang import VoiceClass, download_volume_balanced from AIMeiSheng.docker_demo.common import * import logging hostname = socket.gethostname() log_file_name = f"{os.path.dirname(os.path.abspath(__file__))}/av_meisheng_{hostname}.log" # 设置logger svc_offline_logger = logging.getLogger("svc_offline") file_handler = logging.FileHandler(log_file_name) file_handler.setLevel(logging.INFO) formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S') file_handler.setFormatter(formatter) # if gs_prod: # svc_offline_logger.addHandler(file_handler) if os.path.exists(gs_tmp_dir): shutil.rmtree(gs_tmp_dir) os.makedirs(gs_model_dir, exist_ok=True) os.makedirs(gs_resource_cache_dir, exist_ok=True) # 预设参数 gs_gender_models_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/hub/voice_classification/models.zip" gs_volume_bin_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/dataset/AIMeiSheng/ebur128_tool/v1/ebur128_tool" class GSWorkerAttr: def __init__(self, input_data): # 取出输入资源 vocal_url = input_data["record_song_url"] target_url = input_data["target_url"] start = input_data["start"] # 单位是ms end = input_data["end"] # 单位是ms vocal_loudness = input_data["vocal_loudness"] female_recording_url = input_data["female_recording_url"] male_recording_url = input_data["male_recording_url"] self.distinct_id = hashlib.md5(vocal_url.encode()).hexdigest() self.tmp_dir = os.path.join(gs_tmp_dir, self.distinct_id) if os.path.exists(self.tmp_dir): shutil.rmtree(self.tmp_dir) os.makedirs(self.tmp_dir) self.vocal_url = vocal_url self.target_url = target_url ext = vocal_url.split(".")[-1] self.vocal_path = os.path.join(self.tmp_dir, self.distinct_id + f"_in.{ext}") self.target_wav_path = os.path.join(self.tmp_dir, self.distinct_id + "_out.wav") self.target_wav_ad_path = os.path.join(self.tmp_dir, self.distinct_id + "_out_ad.wav") self.target_path = os.path.join(self.tmp_dir, self.distinct_id + "_out.m4a") self.female_svc_source_url = female_recording_url self.male_svc_source_url = male_recording_url ext = female_recording_url.split(".")[-1] self.female_svc_source_path = os.path.join(self.tmp_dir, self.distinct_id + f"_female.{ext}") ext = male_recording_url.split(".")[-1] self.male_svc_source_path = os.path.join(self.tmp_dir, self.distinct_id + f"_male.{ext}") # self.female_svc_source_path = os.path.join(gs_resource_cache_dir, # hashlib.md5(female_recording_url.encode()).hexdigest() + "." + ext) # ext = male_recording_url.split(".")[-1] # self.male_svc_source_path = os.path.join(gs_resource_cache_dir, # hashlib.md5(male_recording_url.encode()).hexdigest() + "." + ext) self.st_tm = start self.ed_tm = end self.target_loudness = vocal_loudness def log_info_name(self): return f"d_id={self.distinct_id}, vocal_url={self.vocal_url}" def rm_cache(self): if os.path.exists(self.tmp_dir): shutil.rmtree(self.tmp_dir) def init_gender_model(): """ 下载模型 :return: """ dst_model_dir = os.path.join(gs_model_dir, "voice_classification") if not os.path.exists(dst_model_dir): dst_zip_path = os.path.join(gs_model_dir, "models.zip") if not download2disk(gs_gender_models_url, dst_zip_path): svc_offline_logger.fatal(f"download gender_model err={gs_gender_models_url}") cmd = f"cd {gs_model_dir}; unzip {dst_zip_path}; mv models voice_classification; rm -f {dst_zip_path}" os.system(cmd) if not os.path.exists(dst_model_dir): svc_offline_logger.fatal(f"unzip {dst_zip_path} err") music_voice_pure_model = os.path.join(dst_model_dir, "voice_005_rec_v5.pth") music_voice_no_pure_model = os.path.join(dst_model_dir, "voice_10_v5.pth") gender_pure_model = os.path.join(dst_model_dir, "gender_8k_ratev5_v6_adam.pth") gender_no_pure_model = os.path.join(dst_model_dir, "gender_8k_v6_adam.pth") vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model) return vc def init_svc_model(): meisheng_env_prepare(logging, gs_model_dir) embed_model, hubert_model = load_model() cs_sim = cos_similar() return embed_model, hubert_model, cs_sim def download_volume_adjustment(): """ 下载音量调整工具 :return: """ volume_bin_path = os.path.join(gs_model_dir, "ebur128_tool") if not os.path.exists(volume_bin_path): if not download2disk(gs_volume_bin_url, volume_bin_path): svc_offline_logger.fatal(f"download volume_bin err={gs_volume_bin_url}") os.system(f"chmod +x {volume_bin_path}") def volume_adjustment(wav_path, target_loudness, out_path): """ 音量调整 :param wav_path: :param target_loudness: :param out_path: :return: """ volume_bin_path = os.path.join(gs_model_dir, "ebur128_tool") cmd = f"{volume_bin_path} {wav_path} {target_loudness} {out_path}" os.system(cmd) class SVCOnline: def __init__(self): st = time.time() self.gender_model = init_gender_model() self.embed_model, self.hubert_model, self.cs_sim = init_svc_model() download_volume_adjustment() download_volume_balanced() svc_offline_logger.info(f"svc init finished, sp = {time.time() - st}") def gender_process(self, worker_attr): st = time.time() gender, female_rate, is_pure = self.gender_model.process(worker_attr.vocal_path) + svc_offline_logger.info( f"{worker_attr.vocal_url}, gender={gender}, female_rate={female_rate}, is_pure={is_pure}, " f"gender_process sp = {time.time() - st}") if gender == 0: gender = 'female' elif gender == 1: gender = 'male' elif female_rate == None: gender = 'male' return gender, gs_err_code_gender_classify elif female_rate > 0.5: gender = 'female' else: gender = 'male' + if gender == 'female': + if self.gender_model.vocal_ratio < 0.5: + print(f"@@@ vocal_ratio: {self.gender_model.vocal_ratio}, gender : {gender}, gs_err_code_vocal_ratio : {gs_err_code_vocal_ratio}") + return gender, gs_err_code_vocal_ratio + else: + if self.gender_model.vocal_ratio < 0.6: + print(f"@@@ vocal_ratio: {self.gender_model.vocal_ratio}, gender : {gender}, gs_err_code_vocal_ratio : {gs_err_code_vocal_ratio}") + return gender, gs_err_code_vocal_ratio + svc_offline_logger.info(f"{worker_attr.vocal_url}, modified gender={gender}") # err = gs_err_code_success # if female_rate == -1: # err = gs_err_code_target_silence return gender, gs_err_code_success def process(self, worker_attr): gender, err = self.gender_process(worker_attr) if err != gs_err_code_success: return gender, err song_path = worker_attr.female_svc_source_path if gender == "male": song_path = worker_attr.male_svc_source_path params = {'gender': gender, 'tst': worker_attr.st_tm, "tnd": worker_attr.ed_tm, 'delay': 0, 'song_path': None} st = time.time() err_code = process_svc_online(song_path, worker_attr.vocal_path, worker_attr.target_wav_path, self.embed_model, self.hubert_model, self.cs_sim, params) svc_offline_logger.info(f"{worker_attr.vocal_url}, err_code={err_code} process svc sp = {time.time() - st}") return gender, err_code diff --git a/AIMeiSheng/voice_classification/online/voice_class_online_fang.py b/AIMeiSheng/voice_classification/online/voice_class_online_fang.py index d792db0..a74173d 100644 --- a/AIMeiSheng/voice_classification/online/voice_class_online_fang.py +++ b/AIMeiSheng/voice_classification/online/voice_class_online_fang.py @@ -1,444 +1,446 @@ """ 男女声分类在线工具 1 转码为16bit单声道 2 均衡化 3 模型分类 """ import os import sys import librosa import shutil import logging import time import torch.nn.functional as F import numpy as np from model import * # from common import bind_kernel logging.basicConfig(level=logging.INFO) os.environ["LRU_CACHE_CAPACITY"] = "1" # torch.set_num_threads(1) # bind_kernel(1) """ 临时用一下,全局使用的变量 """ transcode_time = 0 vb_time = 0 mfcc_time = 0 predict_time = 0 """ 错误码 """ ERR_CODE_SUCCESS = 0 # 处理成功 ERR_CODE_NO_FILE = -1 # 文件不存在 ERR_CODE_TRANSCODE = -2 # 转码失败 ERR_CODE_VOLUME_BALANCED = -3 # 均衡化失败 ERR_CODE_FEATURE_TOO_SHORT = -4 # 特征文件太短 """ 常量 """ FRAME_LEN = 128 MFCC_LEN = 80 gs_bin_url = "https://av-audit-sync-sg-1256122840.cos.ap-singapore.myqcloud.com/hub/voice_classification/bin/bin.zip" EBUR128_BIN = "/tmp/voice_class_bin/standard_audio_no_cut" # EBUR128_BIN = "/Users/yangjianli/linux/opt/soft/bin/standard_audio_no_cut" GENDER_FEMALE = 0 GENDER_MALE = 1 GENDER_OTHER = 2 """ 通用函数 """ def exec_cmd(cmd): ret = os.system(cmd) if ret != 0: return False return True """ 业务需要的函数 """ def get_one_mfcc(file_url): st = time.time() data, sr = librosa.load(file_url, sr=16000) if len(data) < 512: return [] mfcc = librosa.feature.mfcc(y=data, sr=sr, n_fft=512, hop_length=256, n_mfcc=MFCC_LEN) mfcc = mfcc.transpose() print("get_one_mfcc:spend_time={}".format(time.time() - st)) global mfcc_time mfcc_time += time.time() - st return mfcc def download_volume_balanced(): import urllib.request if not os.path.exists(EBUR128_BIN): dst_path = "/tmp/bin.zip" urllib.request.urlretrieve(gs_bin_url, dst_path) if not os.path.exists(dst_path): print(f"download dst_path={gs_bin_url} err!") exit(-1) dirname = os.path.dirname(dst_path) cmd = f"cd {dirname}; unzip bin.zip; rm -f bin.zip; mv bin voice_class_bin" os.system(cmd) if not os.path.exists(EBUR128_BIN): print(f"exec {cmd} err!") exit(-1) def volume_balanced(src, dst): st = time.time() download_volume_balanced() cmd = "{} {} {}".format(EBUR128_BIN, src, dst) logging.info(cmd) exec_cmd(cmd) if not os.path.exists(dst): logging.error("volume_balanced:cmd={}".format(cmd)) print("volume_balanced:spend_time={}".format(time.time() - st)) global vb_time vb_time += time.time() - st return os.path.exists(dst) def transcode(src, dst): st = time.time() cmd = "ffmpeg -loglevel quiet -i {} -ar 16000 -ac 1 {}".format(src, dst) logging.info(cmd) exec_cmd(cmd) if not os.path.exists(dst): logging.error("transcode:cmd={}".format(cmd)) print("transcode:spend_time={}".format(time.time() - st)) global transcode_time transcode_time += time.time() - st return os.path.exists(dst) class VoiceClass: def __init__(self, music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model): """ 四个模型 :param music_voice_pure_model: 分辨纯净人声/其他 :param music_voice_no_pure_model: 分辨有人声/其他 :param gender_pure_model: 纯净人声分辨男女 :param gender_no_pure_model: 有人声分辨男女 """ st = time.time() self.device = "cpu" self.batch_size = 256 self.music_voice_pure_model = load_model(MusicVoiceV5Model, music_voice_pure_model, self.device) self.music_voice_no_pure_model = load_model(MusicVoiceV5Model, music_voice_no_pure_model, self.device) self.gender_pure_model = load_model(MobileNetV2Gender, gender_pure_model, self.device) self.gender_no_pure_model = load_model(MobileNetV2Gender, gender_no_pure_model, self.device) logging.info("load model ok ! spend_time={}".format(time.time() - st)) + self.vocal_ratio = 1 def batch_predict(self, model, features): st = time.time() scores = [] with torch.no_grad(): for i in range(0, len(features), self.batch_size): cur_data = features[i:i + self.batch_size].to(self.device) predicts = model(cur_data) predicts_score = F.softmax(predicts, dim=1) scores.extend(predicts_score.cpu().numpy()) ret = np.array(scores) global predict_time predict_time += time.time() - st return ret def predict_pure(self, filename, features): scores = self.batch_predict(self.music_voice_pure_model, features) new_features = [] for idx, score in enumerate(scores): if score[0] > 0.5: # 非人声 continue new_features.append(features[idx].numpy()) # 人声段太少,不能进行处理 # 参数可以改 new_feature_len = len(new_features) new_feature_rate = len(new_features) / len(features) + self.vocal_ratio = new_feature_rate # 修改人声占比低一些 if new_feature_len < 4 or new_feature_rate < 0.05: logging.warning( "filename={}|predict_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate) ) return GENDER_OTHER, -1 new_features = torch.from_numpy(np.array(new_features)) scores = self.batch_predict(self.gender_pure_model, new_features) f_avg = sum(scores[:, 0]) / len(scores) m_avg = sum(scores[:, 1]) / len(scores) female_rate = f_avg / (f_avg + m_avg) if female_rate > 0.65: return GENDER_FEMALE, female_rate if female_rate < 0.12: return GENDER_MALE, female_rate logging.warning( "filename={}|predict_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate) ) return GENDER_OTHER, female_rate def predict_no_pure(self, filename, features): scores = self.batch_predict(self.music_voice_no_pure_model, features) new_features = [] for idx, score in enumerate(scores): if score[0] > 0.5: # 非人声 continue new_features.append(features[idx].numpy()) # 人声段太少,不能进行处理 # 参数可以改 new_feature_len = len(new_features) new_feature_rate = len(new_features) / len(features) # 修改人声占比低一些 if new_feature_len < 4 or new_feature_rate < 0.05: logging.warning( "filename={}|predict_no_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate) ) return GENDER_OTHER, -1 new_features = torch.from_numpy(np.array(new_features)) scores = self.batch_predict(self.gender_no_pure_model, new_features) f_avg = sum(scores[:, 0]) / len(scores) m_avg = sum(scores[:, 1]) / len(scores) female_rate = f_avg / (f_avg + m_avg) if female_rate > 0.75: return GENDER_FEMALE, female_rate if female_rate < 0.1: return GENDER_MALE, female_rate logging.warning( "filename={}|predict_no_pure|other|len={}|rate={}".format(filename, new_feature_len, new_feature_rate) ) return GENDER_OTHER, female_rate def predict(self, filename, features): st = time.time() new_features = [] for i in range(FRAME_LEN, len(features), FRAME_LEN): new_features.append(features[i - FRAME_LEN: i]) new_features = torch.from_numpy(np.array(new_features)) gender, rate = self.predict_pure(filename, new_features) if gender == GENDER_OTHER: logging.info("start no pure process...") gender, rate = self.predict_no_pure(filename, new_features) return gender, rate, False print("predict|spend_time={}".format(time.time() - st)) return gender, rate, True def process_one_logic(self, filename, file_path, cache_dir): tmp_wav = os.path.join(cache_dir, "tmp.wav") tmp_vb_wav = os.path.join(cache_dir, "tmp_vb.wav") if not transcode(file_path, tmp_wav): return ERR_CODE_TRANSCODE, None, None if not volume_balanced(tmp_wav, tmp_vb_wav): return ERR_CODE_VOLUME_BALANCED, None, None features = get_one_mfcc(tmp_vb_wav) if len(features) < FRAME_LEN: logging.error("feature too short|file_path={}".format(file_path)) return ERR_CODE_FEATURE_TOO_SHORT, None, None return self.predict(filename, features) def process_one(self, file_path): base_dir = os.path.dirname(file_path) filename = os.path.splitext(file_path)[0] print("filename:", filename) cache_dir = os.path.join(base_dir, filename + "_cache") if os.path.exists(cache_dir): shutil.rmtree(cache_dir) os.makedirs(cache_dir) ret = self.process_one_logic(filename, file_path, cache_dir) shutil.rmtree(cache_dir) return ret def process(self, file_path): gender, female_rate, is_pure = self.process_one(file_path) logging.info("{}|gender={}|female_rate={}".format(file_path, gender, female_rate)) return gender, female_rate, is_pure def process_by_feature(self, feature_file): """ 直接处理特征文件 :param feature_file: :return: """ filename = os.path.splitext(feature_file)[0] features = np.load(feature_file) gender, female_rate = self.predict(filename, features) return gender, female_rate def test_all_feature(): import glob base_dir = "/data/datasets/music_voice_dataset_full/feature_online_data_v3" female = glob.glob(os.path.join(base_dir, "female/*feature.npy")) male = glob.glob(os.path.join(base_dir, "male/*feature.npy")) other = glob.glob(os.path.join(base_dir, "other/*feature.npy")) model_path = "/data/jianli.yang/voice_classification/online/models" music_voice_pure_model = os.path.join(model_path, "voice_005_rec_v5.pth") music_voice_no_pure_model = os.path.join(model_path, "voice_10_v5.pth") gender_pure_model = os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth") gender_no_pure_model = os.path.join(model_path, "gender_8k_v6_adam.pth") vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model) tot_st = time.time() ret_map = { 0: {0: 0, 1: 0, 2: 0}, 1: {0: 0, 1: 0, 2: 0}, 2: {0: 0, 1: 0, 2: 0} } for file in female: st = time.time() print("------------------------------>>>>>") gender, female_score = vc.process_by_feature(file) ret_map[0][gender] += 1 if gender != 0: print("err:female->{}|{}|{}".format(gender, file, female_score)) print("process|spend_tm=={}".format(time.time() - st)) for file in male: st = time.time() print("------------------------------>>>>>") gender, female_score = vc.process_by_feature(file) ret_map[1][gender] += 1 if gender != 1: print("err:male->{}|{}|{}".format(gender, file, female_score)) print("process|spend_tm=={}".format(time.time() - st)) for file in other: st = time.time() print("------------------------------>>>>>") gender, female_score = vc.process_by_feature(file) ret_map[2][gender] += 1 if gender != 2: print("err:other->{}|{}|{}".format(gender, file, female_score)) print("process|spend_tm=={}".format(time.time() - st)) global transcode_time, vb_time, mfcc_time, predict_time print("spend_time:tot={}|transcode={}|vb={}|gen_feature={}|predict={}".format(time.time() - tot_st, transcode_time, vb_time, mfcc_time, predict_time)) f_f = ret_map[0][0] f_m = ret_map[0][1] f_o = ret_map[0][2] m_f = ret_map[1][0] m_m = ret_map[1][1] m_o = ret_map[1][2] o_f = ret_map[2][0] o_m = ret_map[2][1] o_o = ret_map[2][2] print("ff:{},fm:{},fo:{}".format(f_f, f_m, f_o)) print("mm:{},mf:{},mo:{}".format(m_m, m_f, m_o)) print("om:{},of:{},oo:{}".format(o_m, o_f, o_o)) # 女性准确率和召回率 f_acc = f_f / (f_f + m_f + o_f) f_recall = f_f / (f_f + f_m + f_o) # 男性准确率和召回率 m_acc = m_m / (m_m + f_m + o_m) m_recall = m_m / (m_m + m_f + m_o) print("female: acc={}|recall={}".format(f_acc, f_recall)) print("male: acc={}|recall={}".format(m_acc, m_recall)) def test_all(): import glob base_dir = "/data/datasets/music_voice_dataset_full/online_data_v3_top200" female = glob.glob(os.path.join(base_dir, "female/*mp4")) male = glob.glob(os.path.join(base_dir, "male/*mp4")) other = glob.glob(os.path.join(base_dir, "other/*mp4")) model_path = "/data/jianli.yang/voice_classification/online/models" music_voice_pure_model = os.path.join(model_path, "voice_005_rec_v5.pth") music_voice_no_pure_model = os.path.join(model_path, "voice_10_v5.pth") gender_pure_model = os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth") gender_no_pure_model = os.path.join(model_path, "gender_8k_v6_adam.pth") vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model) tot_st = time.time() ret_map = { 0: {0: 0, 1: 0, 2: 0}, 1: {0: 0, 1: 0, 2: 0}, 2: {0: 0, 1: 0, 2: 0} } for file in female: st = time.time() print("------------------------------>>>>>") gender, female_score = vc.process(file) ret_map[0][gender] += 1 if gender != 0: print("err:female->{}|{}|{}".format(gender, file, female_score)) print("process|spend_tm=={}".format(time.time() - st)) for file in male: st = time.time() print("------------------------------>>>>>") gender, female_score = vc.process(file) ret_map[1][gender] += 1 if gender != 1: print("err:male->{}|{}|{}".format(gender, file, female_score)) print("process|spend_tm=={}".format(time.time() - st)) for file in other: st = time.time() print("------------------------------>>>>>") gender, female_score = vc.process(file) ret_map[2][gender] += 1 if gender != 2: print("err:other->{}|{}|{}".format(gender, file, female_score)) print("process|spend_tm=={}".format(time.time() - st)) global transcode_time, vb_time, mfcc_time, predict_time print("spend_time:tot={}|transcode={}|vb={}|gen_feature={}|predict={}".format(time.time() - tot_st, transcode_time, vb_time, mfcc_time, predict_time)) f_f = ret_map[0][0] f_m = ret_map[0][1] f_o = ret_map[0][2] m_f = ret_map[1][0] m_m = ret_map[1][1] m_o = ret_map[1][2] o_f = ret_map[2][0] o_m = ret_map[2][1] o_o = ret_map[2][2] print("ff:{},fm:{},fo:{}".format(f_f, f_m, f_o)) print("mm:{},mf:{},mo:{}".format(m_m, m_f, m_o)) print("om:{},of:{},oo:{}".format(o_m, o_f, o_o)) # 女性准确率和召回率 f_acc = f_f / (f_f + m_f + o_f) f_recall = f_f / (f_f + f_m + f_o) # 男性准确率和召回率 m_acc = m_m / (m_m + f_m + o_m) m_recall = m_m / (m_m + m_f + m_o) print("female: acc={}|recall={}".format(f_acc, f_recall)) print("male: acc={}|recall={}".format(m_acc, m_recall)) if __name__ == "__main__": # test_all() # test_all_feature() model_path = sys.argv[1] voice_path = sys.argv[2] music_voice_pure_model = os.path.join(model_path, "voice_005_rec_v5.pth") music_voice_no_pure_model = os.path.join(model_path, "voice_10_v5.pth") gender_pure_model = os.path.join(model_path, "gender_8k_ratev5_v6_adam.pth") gender_no_pure_model = os.path.join(model_path, "gender_8k_v6_adam.pth") vc = VoiceClass(music_voice_pure_model, music_voice_no_pure_model, gender_pure_model, gender_no_pure_model) for i in range(0, 1): st = time.time() print("------------------------------>>>>>") gender, female_rate, is_pure = vc.process(voice_path) print("process|spend_tm=={}".format(time.time() - st)) print("gender:{}, female_rate:{},is_pure:{}".format(gender, female_rate, is_pure))