music_gender_class_val_v3.py

"""
使用两个模型来验证歌曲级别的准确率和召回率
生成出结果之后需要使用script/music_voice_class/ana中的代码 v3 代码进行结果分析
"""
import os
import sys
import glob
import numpy as np
import psutil
import time
import torch
import torch.nn.functional
os.environ["LRU_CACHE_CAPACITY"] = "1"
FRAME_LEN = 128
MFCC_LEN = 80
import music_voice_models
import music_gender_models_simple
def get_current_memory_gb():
    # Print the memory usage of the current process (note: the value printed is in MB).
    pid = os.getpid()
    p = psutil.Process(pid)
    info = p.memory_full_info()
    print("cur memory=:{} M".format(info.uss / 1024 / 1024))
class PredictModel:
    """
    Evaluate the models' performance; misclassified songs are written out directly.
    """
    def __init__(self, model_path, model2_path, model3_path, features_dir):
        self.device = 'cuda'
        model = music_voice_models.get_models("v5")()
        params = torch.load(model_path)
        model.load_state_dict(state_dict=params)
        model.eval()
        model2 = music_voice_models.get_models("v5")()
        params = torch.load(model2_path)
        model2.load_state_dict(state_dict=params)
        model2.eval()
        model3 = music_gender_models_simple.get_models("v5")()
        params3 = torch.load(model3_path)
        model3.load_state_dict(state_dict=params3)
        model3.eval()
        self.model = model  # pure vocal / other
        self.model2 = model2  # contains vocal / other
        self.model3 = model3  # male / female voice
        self.model.to(self.device)
        self.model2.to(self.device)
        self.model3.to(self.device)
        self.frame_num = FRAME_LEN
        self.batch_size = 128
        self.features_dir = features_dir
        self._female_files = glob.glob(os.path.join(features_dir, "female/*.feature.npy"))  # female_0
        self._male_files = glob.glob(os.path.join(features_dir, "male/*.feature.npy"))  # male_1
        self._other_files = glob.glob(os.path.join(features_dir, "other/*.feature.npy"))  # other_2
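        # Expected feature-directory layout (inferred from the glob patterns above), where each
        # *.feature.npy file holds the pre-extracted MFCC sequence for one song:
        #   features_dir/female/*.feature.npy
        #   features_dir/male/*.feature.npy
        #   features_dir/other/*.feature.npy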
    def process_one(self, file, gender):
        # Build the data: sliding windows of FRAME_LEN frames over the MFCC sequence.
        mfccs = np.load(file)
        data = []
        for i in range(FRAME_LEN, len(mfccs), 128):  # use a wider stride to reduce computation
            data.append(mfccs[i - FRAME_LEN:i])
        data = torch.from_numpy(np.array(data))
        print("load data ok.... shape={}".format(data.shape))
        # Run prediction
        female_num = 0
        male_num = 0
        other_num = 0
        female_sm = []
        male_sm = []
        # Each output row: filename, gender, idx, then the two softmax scores from each of the three models
        ret_msg = []
        with torch.no_grad():
            batch_size = 256
            for i in range(0, len(data), batch_size):
                cur_data = data[i:i + batch_size].to(self.device)
                predicts = self.model(cur_data)
                predicts_score = torch.nn.functional.softmax(predicts, dim=1)
                _, predicts = predicts.max(dim=1)
                predicts2 = self.model2(cur_data)
                predicts_score2 = torch.nn.functional.softmax(predicts2, dim=1)
                _, predicts2 = predicts2.max(dim=1)
                predicts3 = self.model3(cur_data)
                predicts_score3 = torch.nn.functional.softmax(predicts3, dim=1)
                _, predicts3 = predicts3.max(dim=1)
                print("predict ok...")
                # Accumulate the per-window results
                for j in range(len(predicts)):
                    ret_msg.append(
                        "{},{},{},{},{},{},{},{},{}".format(file, gender, i + j, predicts_score[j][0],
                                                            predicts_score[j][1], predicts_score2[j][0],
                                                            predicts_score2[j][1],
                                                            predicts_score3[j][0],
                                                            predicts_score3[j][1],
                                                            ))
                    male_sm.append(predicts_score2[j][1])
                    female_sm.append(predicts_score2[j][0])
                    if predicts2[j] == 0:
                        female_num += 1
                    if predicts2[j] == 1:
                        male_num += 1
                    if predicts2[j] == 2:
                        other_num += 1
        print("calc ok...")
        print("{},{}".format(sum(female_sm) / len(female_sm), sum(male_sm) / len(male_sm)))
        print("torch {},{},{}....".format(female_num, male_num, other_num))
        # If one class accounts for more than half of the windows, classify the song as
        # female/male; otherwise mark it as undetermined.
        tot = female_num + male_num + other_num
        if female_num / tot > 0.5:
            return 0, ret_msg
        if male_num / tot > 0.5:
            return 1, ret_msg
        return 2, ret_msg
    def process_files(self, files, gender, log_file):
        # Process a set of files that share the same ground-truth label and count the song-level results.
        f_num = 0
        m_num = 0
        o_num = 0
        for file in files:
            ret, ret_msg = self.process_one(file, gender)
            print("file_name={} ret={}".format(file, ret))
            if ret == 0:
                f_num += 1
            elif ret == 1:
                m_num += 1
            else:
                o_num += 1
            # Append the per-window rows to the log file
            with open(log_file, "a") as f:
                for line in ret_msg:
                    f.write(line + "\n")
        print("f_num={}, m_num={}, o_num={}".format(f_num, m_num, o_num))
        return f_num, m_num, o_num
    def process(self, log_file):
        f_num, m_num, o_num = self.process_files(self._female_files, 0, log_file)
        f_num1, m_num1, o_num1 = self.process_files(self._male_files, 1, log_file)
        self.process_files(self._other_files, 2, log_file)
        # Female voice: precision over the female/male file sets, recall over all female files
        f_acc = f_num / (f_num + f_num1)
        f_recall = f_num / len(self._female_files)
        print("female acc={} recall={}".format(f_acc, f_recall))
        # Male voice: precision over the female/male file sets, recall over all male files
        m_acc = m_num1 / (m_num + m_num1)
        m_recall = m_num1 / len(self._male_files)
        print("male acc={} recall={}".format(m_acc, m_recall))
def process_one(model_dir, model_dir2, model_dir3, filepath, log_file):
    pm = PredictModel(model_dir, model_dir2, model_dir3, "")
    pm.process_files([filepath], 0, log_file)
if __name__ == "__main__":
model_dir = sys.argv[1]
model_dir2 = sys.argv[2]
model_dir3 = sys.argv[3]
feature_dir = sys.argv[4]
log_file = sys.argv[5]
mode = sys.argv[6]
if mode == "one":
if os.path.exists(log_file):
os.unlink(log_file)
process_one(model_dir, model_dir2, model_dir3, feature_dir, log_file)
else:
pm = PredictModel(model_dir, model_dir2, model_dir3, feature_dir)
if os.path.exists(log_file):
os.unlink(log_file)
pm.process(log_file)
