diff --git a/AutoCoverTool/svc_inference/webui.py b/AutoCoverTool/svc_inference/webui.py index 9b5de2a..48a7031 100644 --- a/AutoCoverTool/svc_inference/webui.py +++ b/AutoCoverTool/svc_inference/webui.py @@ -1,92 +1,76 @@ """ 构建唱歌音色转换网页(基于3.0) 要求: 1. 音频上传 2. 推理 3. 下载 """ import os import time import glob import shutil import librosa import soundfile import gradio as gr -from online.common import update_db -from ref.so_vits_svc.inference_main import * +from online.inference_one import inf -gs_tmp_dir = "/tmp/svc_inference" +gs_tmp_dir = "/tmp/svc_inference_one_web" gs_model_dir = "/data/prod/so_vits_models/3.0" -gs_test_wav_dir = "/data/prod/so_vits_models/test_svc_file/3.0" gs_config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.json") - - -def generate_svc_file(): - """ - :return: - """ - if not os.path.exists(gs_test_wav_dir): - os.makedirs(gs_test_wav_dir) - test_wav_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../res/syz_test.wav") - model_path_list = glob.glob(os.path.join(gs_model_dir, "*/*pth")) +gs_models_choices = glob.glob(os.path.join(gs_model_dir, "*/*pth")) +gs_model_list_dropdown = None + + +def svc(audio_data, model_path): + sr, data = audio_data + if os.path.exists(gs_tmp_dir): + shutil.rmtree(gs_tmp_dir) + os.makedirs(gs_tmp_dir) + tmp_path = os.path.join(gs_tmp_dir, "tmp.wav") + soundfile.write(tmp_path, data, sr, format="wav") + + # 重采样到32k + audio, sr = librosa.load(tmp_path, sr=32000, mono=True) + tmp_path = os.path.join(gs_tmp_dir, "tmp_32.wav") + out_path = os.path.join(gs_tmp_dir, "out.wav") + soundfile.write(tmp_path, data, sr, format="wav") + + # 推理 + print("svc: {}".format(model_path)) st = time.time() - for idx, model_path in enumerate(model_path_list): - model_name = model_path.strip().split("/")[-1].replace(".pth", "") - dst_path = os.path.join(gs_test_wav_dir, "{}.wav".format(model_name)) - if not os.path.exists(dst_path): - inf(model_path, gs_config_path, test_wav_path, dst_path, "prod") - print("now_per={}/{}".format(idx, len(model_path_list), time.time() - st)) - - -def update_state(gender, user_id): - sql = "update av_db.av_svc_model set gender={} where user_id=\"{}\"".format(gender, user_id) - update_db(sql) - - -# 按钮控制 -def click_male(user_id): - print("click_male={}".format(user_id)) - pass + inf(model_path, gs_config_path, tmp_path, out_path, 'cuda') + print("input d={}, sp = {}".format(len(audio) / sr, time.time() - st)) + return out_path -def click_female(user_id): - print("click_female={}".format(user_id)) - - -def click_delete(user_id): - print("click_delete={}".format(user_id)) +def model_select(): + files = glob.glob(os.path.join(gs_model_dir, "*/*pth")) + return gs_model_list_dropdown.update(choices=files) def main(): # header - st = time.time() - generate_svc_file() - print("generate svc sp={}".format(time.time() - st)) - app = gr.Blocks() with app: # 头部介绍 gr.Markdown(value=""" - ### 人声质量评价 + ### 唱歌音色转换 作者:starmaker音视频 """) - # 列表展示 - # 1. 每一行有音频,性别,删除等按钮 - svc_files = glob.glob(os.path.join(gs_test_wav_dir, "*wav")) - for svc_file in svc_files: - user_id = svc_file.split("/")[-1].replace(".wav", "") - gr.Audio(source=svc_file) - - male_gender_btn = gr.Button("male") - female_gender_btn = gr.Button("female") - del_btn = gr.Button("female") - male_gender_btn.click(click_male, inputs=[user_id]) - female_gender_btn.click(click_female, inputs=[user_id]) - del_btn.click(click_delete, inputs=[user_id]) + global gs_model_list_dropdown + gs_model_list_dropdown = gr.Dropdown(choices=gs_models_choices, interactive=True, label="model list") + refresh_btn = gr.Button("refresh_model_list") + refresh_btn.click(fn=model_select, inputs=[], outputs=gs_model_list_dropdown) + + # 提示词输入框 + input_audio = gr.inputs.Audio(label="input") + gen_btn = gr.Button("generate", variant="primary") + output_audio = gr.outputs.Audio(label="output", type='filepath') + gen_btn.click(fn=svc, inputs=[input_audio, gs_model_list_dropdown], outputs=output_audio) # 本方法实现同一时刻只有一个程序在服务器端运行 app.queue(concurrency_count=1, max_size=2044).launch(server_name="0.0.0.0", inbrowser=True, quiet=True, server_port=7860) if __name__ == '__main__': main()