diff --git a/AIMeiSheng/myinfer_multi_spk_embed_in_dec_diff_fi_meisheng.py b/AIMeiSheng/myinfer_multi_spk_embed_in_dec_diff_fi_meisheng.py
index b68dc3b..f1da5a9 100644
--- a/AIMeiSheng/myinfer_multi_spk_embed_in_dec_diff_fi_meisheng.py
+++ b/AIMeiSheng/myinfer_multi_spk_embed_in_dec_diff_fi_meisheng.py
@@ -1,215 +1,217 @@
 import os, sys, pdb, torch
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 import argparse
 import glob
 import sys
 import torch
 from multiprocessing import cpu_count


 class Config:
     def __init__(self, device, is_half):
         self.device = device
         self.is_half = is_half
         self.n_cpu = 0
         self.gpu_name = None
         self.gpu_mem = None
         self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()

     def device_config(self) -> tuple:
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        config_path = os.path.join(current_dir, "configs")
         if torch.cuda.is_available():
             i_device = int(self.device.split(":")[-1])
             self.gpu_name = torch.cuda.get_device_name(i_device)
             if (
                 ("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
                 or "P40" in self.gpu_name.upper()
                 or "1060" in self.gpu_name
                 or "1070" in self.gpu_name
                 or "1080" in self.gpu_name
             ):
                 print("16-series/10-series GPUs and P40 are forced to single precision")
                 self.is_half = False
                 for config_file in ["32k.json", "40k.json", "48k.json"]:
-                    with open(f"configs/{config_file}", "r") as f:
+                    with open(f"{config_path}/{config_file}", "r") as f:
                         strr = f.read().replace("true", "false")
-                    with open(f"configs/{config_file}", "w") as f:
+                    with open(f"{config_path}/{config_file}", "w") as f:
                         f.write(strr)
-                with open("trainset_preprocess_pipeline_print.py", "r") as f:
+                with open(f"{current_dir}/trainset_preprocess_pipeline_print.py", "r") as f:
                     strr = f.read().replace("3.7", "3.0")
-                with open("trainset_preprocess_pipeline_print.py", "w") as f:
+                with open(f"{current_dir}/trainset_preprocess_pipeline_print.py", "w") as f:
                     f.write(strr)
             else:
                 self.gpu_name = None
             self.gpu_mem = int(
                 torch.cuda.get_device_properties(i_device).total_memory
                 / 1024
                 / 1024
                 / 1024
                 + 0.4
             )
             if self.gpu_mem <= 4:
-                with open("trainset_preprocess_pipeline_print.py", "r") as f:
+                with open(f"{current_dir}/trainset_preprocess_pipeline_print.py", "r") as f:
                     strr = f.read().replace("3.7", "3.0")
-                with open("trainset_preprocess_pipeline_print.py", "w") as f:
+                with open(f"{current_dir}/trainset_preprocess_pipeline_print.py", "w") as f:
                     f.write(strr)
         elif torch.backends.mps.is_available():
             print("No supported NVIDIA GPU found, using MPS for inference")
             self.device = "mps"
         else:
             print("No supported NVIDIA GPU found, using CPU for inference")
             self.device = "cpu"
             self.is_half = True
         if self.n_cpu == 0:
             self.n_cpu = cpu_count()

         if self.is_half:
             # Settings for 6 GB VRAM
             x_pad = 3
             x_query = 10
             x_center = 80  # 60
             x_max = 85  # 65
         else:
             # Settings for 5 GB VRAM
             x_pad = 1
             x_query = 6
             x_center = 38
             x_max = 41

         if self.gpu_mem != None and self.gpu_mem <= 4:
             x_pad = 1
             x_query = 5
             x_center = 30
             x_max = 32
         return x_pad, x_query, x_center, x_max


 index_path = "./logs/xusong_v2_org_version_multispk_charlie_puth_embed_in_dec_muloss_show/added_IVF614_Flat_nprobe_1_xusong_v2_org_version_multispk_charlie_puth_embed_in_dec_show_v2.index"
 # f0method = "rmvpe"  # harvest or pm
 index_rate = float("0.0")  # index rate
 device = "cuda:0"
 is_half = True
 filter_radius = int(3)  # 3
 resample_sr = int(0)  # 0
 rms_mix_rate = float(1)  # RMS mix ratio: 1 means no mixing, other values blend
 protect = float(0.33)  # ??? 0.33 fang
 # print(sys.argv)

 config = Config(device, is_half)
 now_dir = os.getcwd()
 sys.path.append(now_dir)

 from vc_infer_pipeline_org_embed import VC
 from lib.infer_pack.models_embed_in_dec_diff_fi import (
     SynthesizerTrnMs256NSFsid,
     SynthesizerTrnMs256NSFsid_nono,
     SynthesizerTrnMs768NSFsid,
     SynthesizerTrnMs768NSFsid_nono,
 )
 from lib.audio import load_audio
 from fairseq import checkpoint_utils
 from scipy.io import wavfile
 from AIMeiSheng.docker_demo.common import gs_hubert_model_path


 # hubert_model = None
 def load_hubert():
     # global hubert_model
     models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([gs_hubert_model_path], suffix="")
     # models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(["checkpoint_best_legacy_500.pt"], suffix="")
     hubert_model = models[0]
     hubert_model = hubert_model.to(device)
     if (is_half):
         hubert_model = hubert_model.half()
     else:
         hubert_model = hubert_model.float()
     hubert_model.eval()
     return hubert_model


 def vc_single(sid, input_audio, f0_up_key, f0_file, f0_method, file_index, index_rate, hubert_model, paras):
     global tgt_sr, net_g, vc, version
     if input_audio is None:
         return "You need to upload an audio", None
     f0_up_key = int(f0_up_key)
     # print("@@xxxf0_up_key:", f0_up_key)
     audio = load_audio(input_audio, 16000)
     if paras != None:
         st = int(paras['tst'] * 16000 / 1000)
         en = len(audio)
         if paras['tnd'] != None:
             en = min(en, int(paras['tnd'] * 16000 / 1000))
         audio = audio[st:en]
     times = [0, 0, 0]
     if (hubert_model == None):
         hubert_model = load_hubert()
     if_f0 = cpt.get("f0", 1)
     audio_opt = vc.pipeline_mulprocess(hubert_model, net_g, sid, audio, input_audio, times, f0_up_key, f0_method, file_index, index_rate, if_f0, filter_radius, tgt_sr, resample_sr, rms_mix_rate, version, protect, f0_file=f0_file)
     # print(times)
     # print("@@using multi process")
     return audio_opt


 def get_vc_core(model_path, is_half):
     # print("loading pth %s" % model_path)
     cpt = torch.load(model_path, map_location="cpu")
     tgt_sr = cpt["config"][-1]
     cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
     if_f0 = cpt.get("f0", 1)
     version = cpt.get("version", "v1")
     if version == "v1":
         if if_f0 == 1:
             net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
         else:
             net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
     elif version == "v2":
         if if_f0 == 1:
             net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half)
         else:
             net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
     # print("load model finished")
     del net_g.enc_q
     net_g.load_state_dict(cpt["weight"], strict=False)
     # print("load net_g finished")
     return tgt_sr, net_g, cpt, version


 def get_vc1(model_path, is_half):
     tgt_sr, net_g, cpt, version = get_vc_core(model_path, is_half)
     net_g.eval().to(device)
     if (is_half):
         net_g = net_g.half()
     else:
         net_g = net_g.float()
     vc = VC(tgt_sr, config)
     n_spk = cpt["config"][-3]
     return


 def get_rmvpe(model_path="rmvpe.pt"):
     from lib.rmvpe import RMVPE
     global f0_method
     # print("loading rmvpe model")
     f0_method = RMVPE(model_path, is_half=True, device='cuda')
     return f0_method


 def get_vc(model_path):
     global n_spk, tgt_sr, net_g, vc, cpt, device, is_half, version
     tgt_sr, net_g, cpt, version = get_vc_core(model_path, is_half)
     net_g.eval().to(device)
     if (is_half):
         net_g = net_g.half()
     else:
         net_g = net_g.float()
     vc = VC(tgt_sr, config)
     n_spk = cpt["config"][-3]
     # return {"visible": True, "maximum": n_spk, "__type__": "update"}
     # return net_g


 def svc_main(input_path, opt_path, sid_embed, f0up_key=0, hubert_model=None, paras=None):
     # print("sid_embed: ", sid_embed)
     wav_opt = vc_single(sid_embed, input_path, f0up_key, None, f0_method, index_path, index_rate, hubert_model, paras)
     # print("out_path: ", opt_path)
     wavfile.write(opt_path, tgt_sr, wav_opt)
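The substantive change in this file is path anchoring: the config reads/writes previously used paths relative to os.getcwd(), which breaks whenever the module is imported from another working directory. A minimal standalone sketch of the pattern (the helper name `load_sample_rate_config` is illustrative, not from the repo):

```python
import json
import os

# Anchor lookups to this module's directory rather than the current
# working directory, mirroring the current_dir/config_path change above.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def load_sample_rate_config(name):
    """Load e.g. "40k.json" from the module-local configs/ folder."""
    with open(os.path.join(BASE_DIR, "configs", name), "r") as f:
        return json.load(f)
```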
#print("out_path: ",opt_path) wavfile.write(opt_path, tgt_sr, wav_opt) diff --git a/AutoCoverTool/script/common.py b/AutoCoverTool/script/common.py index ab4b5a4..cb472a0 100644 --- a/AutoCoverTool/script/common.py +++ b/AutoCoverTool/script/common.py @@ -1,181 +1,181 @@ # -*-encoding=utf8-*- import time import pymysql import logging import pandas as pd # from impala.dbapi import connect # from sqlalchemy import create_engine # from sqlalchemy.types import NVARCHAR, Float, Integer banned_user_map = { "host": "sg-songbook00.db.starmaker.co", "user": "worker", "passwd": "gRYppQtdTpP3nFzH", "db": "starmaker" } banned_user_map_v1 = { "host": "sg-starmaker-device-r2.db.starmaker.co", "user": "worker", "passwd": "gRYppQtdTpP3nFzH", "db": "mis" } banned_user_map_v2 = { "host": "sg-sm-img-r1.starmaker.co", "user": "worker", "passwd": "gRYppQtdTpP3nFzH", "db": "sm" } # 做一下shared库的查询依赖 shard_map = { "shard_sm_12": "sg-shard02-r2.db.starmaker.co", "shard_sm_13": "sg-shard02-r2.db.starmaker.co", "shard_sm_14": "sg-shard02-r2.db.starmaker.co", "shard_sm_15": "sg-shard02-r2.db.starmaker.co", "shard_sm_30": "sg-shard02-r2.db.starmaker.co", "shard_sm_31": "sg-shard02-r2.db.starmaker.co", "shard_sm_20": "sg-shard02-r2.db.starmaker.co", "shard_sm_21": "sg-shard02-r2.db.starmaker.co", "shard_sm_22": "sg-shard03-r2.db.starmaker.co", "shard_sm_23": "sg-shard03-r2.db.starmaker.co", "shard_sm_24": "sg-shard03-r2.db.starmaker.co", "shard_sm_25": "sg-shard03-r2.db.starmaker.co", "shard_sm_26": "sg-shard03-r2.db.starmaker.co", "shard_sm_27": "sg-shard03-r2.db.starmaker.co", "shard_sm_28": "sg-shard03-r2.db.starmaker.co", "shard_sm_29": "sg-shard03-r2.db.starmaker.co", "shard_sm_0": "sg-shard00-r2.db.starmaker.co", "shard_sm_1": "sg-shard00-r2.db.starmaker.co", "shard_sm_2": "sg-shard00-r2.db.starmaker.co", "shard_sm_3": "sg-shard00-r2.db.starmaker.co", "shard_sm_4": "sg-shard00-r2.db.starmaker.co", "shard_sm_5": "sg-shard00-r2.db.starmaker.co", "shard_sm_16": "sg-shard00-r2.db.starmaker.co", "shard_sm_17": "sg-shard00-r2.db.starmaker.co", "shard_sm_6": "sg-shard01-r2.db.starmaker.co", "shard_sm_7": "sg-shard01-r2.db.starmaker.co", "shard_sm_8": "sg-shard01-r2.db.starmaker.co", "shard_sm_9": "sg-shard01-r2.db.starmaker.co", "shard_sm_10": "sg-shard01-r2.db.starmaker.co", "shard_sm_11": "sg-shard01-r2.db.starmaker.co", "shard_sm_18": "sg-shard01-r2.db.starmaker.co", "shard_sm_19": "sg-shard01-r2.db.starmaker.co", "shard_sm_32": "sg-shard04-r2.db.starmaker.co", "shard_sm_33": "sg-shard04-r2.db.starmaker.co", "shard_sm_34": "sg-shard04-r2.db.starmaker.co", "shard_sm_35": "sg-shard04-r2.db.starmaker.co", "shard_sm_36": "sg-shard04-r2.db.starmaker.co", "shard_sm_37": "sg-shard04-r2.db.starmaker.co", "shard_sm_38": "sg-shard04-r2.db.starmaker.co", "shard_sm_39": "sg-shard04-r2.db.starmaker.co", "shard_sm_40": "sg-shard05-r2.db.starmaker.co", "shard_sm_41": "sg-shard05-r2.db.starmaker.co", "shard_sm_42": "sg-shard05-r2.db.starmaker.co", "shard_sm_43": "sg-shard05-r2.db.starmaker.co", "shard_sm_44": "sg-shard05-r2.db.starmaker.co", "shard_sm_45": "sg-shard05-r2.db.starmaker.co", "shard_sm_46": "sg-shard05-r2.db.starmaker.co", "shard_sm_47": "sg-shard05-r2.db.starmaker.co", "shard_sm_48": "sg-shard05-r2.db.starmaker.co", "shard_sm_49": "sg-shard05-r2.db.starmaker.co", "shard_sm_50": "sg-shard05-r2.db.starmaker.co", "name": "shard_sm_{}", "port": 3306, "user": "readonly", "passwd": "JKw6woZgRXsveegL" } def connect_db(host="research-db-r1.starmaker.co", port=3306, user="root", passwd="Qrdl1130", db=""): - print("connect mysql 
-    print("connect mysql host={} port={} user={} passwd={} db={}".format(host, port, user, passwd, db))
+    # print("connect mysql host={} port={} user={} passwd={} db={}".format(host, port, user, passwd, db))
     return pymysql.connect(host=host, port=port, user=user, passwd=passwd, db=db)


 def get_data_by_mysql(sql, ban=banned_user_map):
     db = connect_db(host=ban["host"], passwd=ban["passwd"], user=ban["user"], db=ban["db"])
     db_cursor = db.cursor()
     if len(sql) < 100:
         print("execute = {}".format(sql))
     else:
         print("execute = {}...".format(sql[:100]))
     db_cursor.execute(sql)
     res = db_cursor.fetchall()
     db_cursor.close()
     db.close()
     print("res size={}".format(len(res)))
     return res


 def get_shard_db(user_id):
     return int(float(user_id)) >> 48


 def get_shard_data_by_sql(sql, user_id):
     shard_id = get_shard_db(user_id)
     db_name = shard_map["name"].format(shard_id)
     host = shard_map[db_name]
     db = connect_db(host=host, passwd=shard_map["passwd"], user=shard_map["user"], db=db_name)
     db_cursor = db.cursor()
-    if len(sql) < 100:
-        print("execute = {}".format(sql))
-    else:
-        print("execute = {}...".format(sql[:100]))
+    # if len(sql) < 100:
+    #     print("execute = {}".format(sql))
+    # else:
+    #     print("execute = {}...".format(sql[:100]))
     db_cursor.execute(sql)
     res = db_cursor.fetchall()
     db_cursor.close()
     db.close()
-    print("res size={}".format(len(res)))
+    # print("res size={}".format(len(res)))
     return res


 # def get_data_by_hql(sql):
 #     logging.info(sql)
 #     ntime = time.time()
 #     conn = connect(host='sg-hive.starmaker.co', port=7001, auth_mechanism='PLAIN', timeout=3600, user="hadoop",
 #                    password="7396&pagesize")
 #     cur = conn.cursor()
 #     cur.execute(sql)
 #     data = cur.fetchall()
 #     cur.close()
 #     conn.close()
 #     logging.info("get sql: eps={}".format(time.time() - ntime))
 #     return data


 def read_file(in_file):
     with open(in_file, "r") as f:
         lines = f.readlines()
     return lines


 def write2file(file_path, data):
     with open(file_path, "w") as f:
         for line in data:
             line += "\n"
             f.write(line)


 # def map_types(df):
 #     dtypedict = {}
 #     for i, j in zip(df.columns, df.dtypes):
 #         if "object" in str(j):
 #             dtypedict.update({i: NVARCHAR(length=255)})
 #         if "float" in str(j):
 #             dtypedict.update({i: Float(precision=2, asdecimal=True)})
 #         if "int" in str(j):
 #             dtypedict.update({i: Integer()})
 #     return dtypedict
 #
 #
 # def write2db(filename, tablename):
 #     engine = create_engine("mysql+mysqldb://{}:{}@{}/{}".format('root', '', 'localhost:3306', 'starmaker'))
 #     con = engine.connect()
 #     df = pd.read_csv(filename)
 #     dtypedict = map_types(df)
 #     df.to_sql(name=tablename, con=con, if_exists='append', index=False, dtype=dtypedict)
\ No newline at end of file
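Both hunks in common.py retire print statements that interpolated full connection credentials, password included, into stdout. If some connection logging is still wanted, one option is to log everything except the secret. A sketch using the standard logging module, not code from this repo:

```python
import logging


def log_db_target(host, port, user, db):
    # The password never enters the format string, so it cannot leak
    # into logs; only the connection target is recorded.
    logging.info("connect mysql host=%s port=%s user=%s db=%s", host, port, user, db)
```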
diff --git a/AutoCoverTool/script/get_user_recordings.py b/AutoCoverTool/script/get_user_recordings.py
index e589dc8..3b92a50 100644
--- a/AutoCoverTool/script/get_user_recordings.py
+++ b/AutoCoverTool/script/get_user_recordings.py
@@ -1,184 +1,257 @@
 """
 Fetch user recordings
 """
 import os
 import time
 import glob
 import json
 import librosa
 import soundfile
 from script.common import *


 def exec_cmd(cmd):
     r = os.popen(cmd)
     text = r.read()
     r.close()
     return text


 def get_d(audio_path):
     cmd = "ffprobe -v quiet -print_format json -show_format -show_streams {}".format(audio_path)
     data = exec_cmd(cmd)
     data = json.loads(data)
     if "format" in data.keys():
         if "duration" in data['format']:
             return float(data["format"]["duration"])
     return 0


 def get_user_recordings(user_id):
     sql = "select id, recording_url from recording where user_id={} and created_on > {} and is_public = 1 and is_deleted = 0 and media_type in (1, 2, 3, 4, 9, 10) ".format(
         user_id, time.time() - 86400 * 30)
     res = get_shard_data_by_sql(sql, user_id)
     true_num = 0
     for id, url in res:
         if download_url(url, user_id, str(id)):
             true_num += 1
         if true_num > 15:
             break


 def download_url(url, uid, rid):
     url = str(url).replace("master.mp4", "origin_master.mp4")
     c_dir = "/data/rsync/jianli.yang/AutoCoverTool/data/train_users/0414_0514/{}".format(uid)
     if not os.path.exists(c_dir):
         os.makedirs(c_dir)
     c_dir = os.path.join(c_dir, "src")
     if not os.path.exists(c_dir):
         os.makedirs(c_dir)
     cmd = "wget {} -O {}/{}.mp4".format(url, c_dir, rid)
     os.system(cmd)

     # Transcode to 44.1 kHz mono audio
     in_path = os.path.join(c_dir, rid + ".mp4")
     if os.path.exists(in_path):
         duration = get_d(in_path)
         print("duration={}".format(duration))
         if duration > 30:
             dst_path = in_path.replace(".mp4", ".wav")
             cmd = "ffmpeg -i {} -ar 44100 -ac 1 -y {}".format(in_path, dst_path)
             print("exec={}".format(cmd))
             os.system(cmd)
             return os.path.exists(dst_path)
     return False


 def split_to_idx(ppath, dst_path, user_id):
     frame_len = 32000 * 15
     files = glob.glob(os.path.join(ppath, "*mp4"))
     mmax = 0
     for file in files:
         try:
             audio, sr = librosa.load(file, sr=32000, mono=True)
         except Exception as ex:
             continue
         print("audio_len:={}".format(audio.shape))
         for i in range(0, len(audio), frame_len):
             if i + frame_len > len(audio):
                 break
             cur_data = audio[i:i + frame_len]
             out_path = os.path.join(dst_path, "{}_{}.wav".format(user_id, mmax))
             print("save to {}".format(out_path))
             # librosa.output.write_wav(out_path, cur_data, 32000)
             soundfile.write(out_path, cur_data, 32000, format="wav")
             mmax += 1


 def process():
     from online.beanstalk_helper import BeanstalkHelper
     config = {"addr": "sg-test-common-box-1:11300", "consumer": "auto_cover_tool_download_user"}
     bean_helper = BeanstalkHelper(config)
     bean = bean_helper.get_beanstalkd()
     bean.watch(config["consumer"])
     while True:
         payload = bean.reserve(5)
         if not payload:
             logging.info("bean sleep...")
             continue
         in_data = json.loads(payload.body)
         user_id = in_data["user_id"]
         try:
             user_id_int = int(float(user_id))
             get_user_recordings(in_data["user_id"])
         except Exception as ex:
             pass
         payload.delete()


 def put_data(file_path):
     lines = []
     with open(file_path, "r") as f:
         while True:
             line = f.readline().strip()
             if not line:
                 break
             lines.append(line)
     from online.beanstalk_helper import BeanstalkHelper
     config = {"addr": "sg-test-common-box-1:11300", "consumer": "auto_cover_tool_download_user"}
     bean_helper = BeanstalkHelper(config)
     for idx, line in enumerate(lines):
         if idx == 0:
             continue
         user_id = line.split(",")[0]
         message = json.dumps({"user_id": str(user_id)})
         bean_helper.put_payload_to_beanstalk(config["consumer"], message)


 def copy_data():
     base_dir = "/data/rsync/jianli.yang/AutoCoverTool/data/train_users/0414_0514"
     dst_dir = "/data/rsync/jianli.yang/AutoCoverTool/data/train_users/0414_0514_finish"
     # Only keep users with 10 or more dry-vocal tracks
     dirs = glob.glob(os.path.join(base_dir, "*"))
     for cur_dir in dirs:
         cur_name = cur_dir.split("/")[-1]
         cur_mp4_files = glob.glob(os.path.join(cur_dir, "src/*wav"))
         if len(cur_mp4_files) > 10:
             print("mv {} {}".format(cur_dir, os.path.join(dst_dir, cur_name)))


+import urllib.request
+
+url = "https://testtest.mp4"
+
+try:
+    status = urllib.request.urlopen(url).code
+    print(status)
+except Exception as err:
+    print(err)
+
+
+def check_exists(url):
+    try:
+        status = urllib.request.urlopen(url).code
+        return status == 200
+    except Exception as err:
+        return False
+
+
+def get_vocal_url():
+    arr = []
+    with open("new.txt") as f:
+        while True:
+            line = f.readline()
+            line = line.strip()
+            if not line:
+                break
+            arr.append(line.split("\t"))
+
+    first_line = arr[0][:3].copy()
+    for i in range(3, len(arr[0])):
+        first_line.append(arr[0][i])
+    first_line.append("vocal_url")
+    first_line.append("delay_time")
+    out_arr = [first_line]
+
+    st = time.time()
+    for i in range(1, 30):
+        cur_line = arr[0][:3].copy()
+        female_cnt = 0
+        male_cnt = 0
+        for j in range(3, len(arr[i])):
+            cur_line.append(arr[i][j])
+            if arr[i][j].isdigit():
+                rid = arr[i][j]
+                sql = f"select id, recording_url, created_on from recording where id={rid}"
+                res = get_shard_data_by_sql(sql, rid)
+                if len(res) > 0:
+                    origin_master_url = res[0][1].replace("master.mp4", "origin_master.mp4").replace("recordings", "recordings_origin")
+                    sql = f"select recording_id, delay_time from recording_extra_info where recording_id={rid}"
+                    res1 = get_shard_data_by_sql(sql, rid)
+                    # print(origin_master_url)
+                    if check_exists(origin_master_url):
+                        cur_line.append(origin_master_url)
+                        cur_line.append(str(res1[0][1]))
+                        if j < 8:
+                            male_cnt += 1
+                        else:
+                            female_cnt += 1
+                        continue
+            cur_line.append("-1")
+            cur_line.append("-1")
+        if male_cnt > 0 and female_cnt > 0:
+            out_arr.append(cur_line)
+        if i % 10 == 0:
+            print(f"percent={i}/{len(arr)} sp = {time.time() - st}, cnt={len(out_arr)}")
+
+    with open("new_vocal.txt", "w") as f:
+        for line in out_arr:
+            f.write(",".join(line) + "\n")
+
+
 if __name__ == '__main__':
-    process()
+    get_vocal_url()
+    # process()
     # put_data("res/0414_0514.csv")
     # arr = [
     #     "5348024335101054",
     #     "4222124657245641",
     #     "5629499489117674",
     #     "12384898975368914",
     #     "5629499489839033",
     #     "5348024336648185",
     #     "5910973794961321",
     #     "3635518643",
     #     "844424937670811",
     #     "4785074600577375",
     #     "6755399442719465",
     #     "4785074603156924",
     #     "11540474053041727",
     #     "6473924129711210",
     #     "7036874421386111",
     #     "7599824376482810",
     #     "6755399447475416",
     #     "8444249306118343",
     #     "3377699721107378",
     #     "12947848931397021",
     #     "7599824374449011",
     #     "3096224748076687",
     #     "12103424006572822",
     #     "1125899914308640",
     #     "12666373952417962",
     #     "281474982845813",
     #     "11821949029679778",
     #     "12947848937379499",
     #     "12947848936090348",
     #     "3096224747262571",
     #     "2814749767432467",
     #     "5066549357604730",
     #     "3096224751151928"
     # ]
     # for uuid in arr:
     #     get_user_recordings(uuid)
     #     print("finish =={} ".format(uuid))
     # copy_data()
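The new `check_exists` (and the module-level probe of https://testtest.mp4, which runs at import time) issues a plain GET with no timeout, so verifying a URL transfers the whole media body and can hang indefinitely on a dead host. A possible refinement, a hedged sketch rather than part of this diff (`check_exists_head` is a hypothetical name, and note some servers reject HEAD, in which case a ranged GET would be the fallback):

```python
import urllib.error
import urllib.request


def check_exists_head(url, timeout=5):
    # HEAD returns only headers, so the (possibly large) mp4 body is never
    # transferred; the timeout bounds how long an unresponsive host can stall us.
    req = urllib.request.Request(url, method="HEAD")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, ValueError):
        # URLError covers HTTPError (4xx/5xx) and network failures;
        # ValueError covers malformed URLs.
        return False
```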