fix: 卸载音色省显存

顺便将所有print换成了统一的logger
This commit is contained in:
源文雨
2023-09-01 15:18:08 +08:00
parent 8d5a77dbe9
commit 04a33b9709
23 changed files with 189 additions and 106 deletions

View File

@@ -13,9 +13,10 @@ logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
i18n = I18nAuto()
i18n.print()
logger.info(i18n)
load_dotenv()
config = Config()

View File

@@ -1,7 +1,8 @@
# This code references https://huggingface.co/JosephusCheung/ASimilarityCalculatior/blob/main/qwerty.py
# Fill in the path of the model to be queried and the root directory of the reference models, and this script will return the similarity between the model to be queried and all reference models.
import os
import sys
import logging
logger = logging.getLogger(__name__)
import torch
import torch.nn as nn
@@ -55,7 +56,7 @@ def main(path, root):
torch.manual_seed(114514)
model_a = torch.load(path, map_location="cpu")["weight"]
print("Query:\t\t%s\t%s" % (path, model_hash(path)))
logger.info("Query:\t\t%s\t%s" % (path, model_hash(path)))
map_attn_a = {}
map_rand_input = {}
@@ -82,7 +83,7 @@ def main(path, root):
sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
sims.append(sim)
print(
logger.info(
"Reference:\t%s\t%s\t%s"
% (path, model_hash(path), f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%")
)

View File

@@ -3,7 +3,8 @@
对源特征进行检索
"""
import os
import pdb
import logging
logger = logging.getLogger(__name__)
import parselmouth
import torch
@@ -15,7 +16,6 @@ from time import time as ttime
# import pyworld
import librosa
import numpy as np
import scipy.signal as signal
import soundfile as sf
import torch.nn.functional as F
from fairseq import checkpoint_utils
@@ -34,7 +34,7 @@ from scipy.io import wavfile
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = r"E:\codes\py39\vits_vc_gpu_train\assets\hubert\hubert_base.pt" #
print("Load model(s) from {}".format(model_path))
logger.info("Load model(s) from {}".format(model_path))
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
[model_path],
suffix="",
@@ -77,7 +77,7 @@ net_g = SynthesizerTrn256(
# weights=torch.load("infer/ft-mi-freeze-vocoder_true_1k.pt")
# weights=torch.load("infer/ft-mi-sim1k.pt")
weights = torch.load("infer/ft-mi-no_opt-no_dropout.pt")
print(net_g.load_state_dict(weights, strict=True))
logger.debug(net_g.load_state_dict(weights, strict=True))
net_g.eval().to(device)
net_g.half()
@@ -198,4 +198,4 @@ for idx, name in enumerate(
wavfile.write("ft-mi-no_opt-no_dropout-%s.wav" % name, 40000, audio) ##
print(ta0, ta1, ta2) #
# Logger methods take (format_string, *args); passing the values print-style
# would make ta0 the "format" and fail when the record is formatted.
logger.debug("%s %s %s", ta0, ta1, ta2)

View File

@@ -3,6 +3,9 @@
"""
import os
import traceback
import logging
logger = logging.getLogger(__name__)
from multiprocessing import cpu_count
import faiss
@@ -23,11 +26,11 @@ big_npy = np.concatenate(npys, 0)
big_npy_idx = np.arange(big_npy.shape[0])
np.random.shuffle(big_npy_idx)
big_npy = big_npy[big_npy_idx]
print(big_npy.shape) # (6196072, 192)#fp32#4.43G
logger.debug(big_npy.shape) # (6196072, 192)#fp32#4.43G
if big_npy.shape[0] > 2e5:
# if(1):
info = "Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0]
print(info)
logger.info(info)
try:
big_npy = (
MiniBatchKMeans(
@@ -42,7 +45,7 @@ if big_npy.shape[0] > 2e5:
)
except:
info = traceback.format_exc()
print(info)
# Logger.warn is a deprecated alias; Logger.warning is the supported method.
logger.warning(info)
np.save("tools/infer/big_src_feature_mi.npy", big_npy)
@@ -50,14 +53,14 @@ np.save("tools/infer/big_src_feature_mi.npy", big_npy)
# big_npy=np.load("/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/inference_f0/big_src_feature_mi.npy")
n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
index = faiss.index_factory(768, "IVF%s,Flat" % n_ivf) # mi
print("Training...")
logger.info("Training...")
index_ivf = faiss.extract_index_ivf(index) #
index_ivf.nprobe = 1
index.train(big_npy)
faiss.write_index(
index, "tools/infer/trained_IVF%s_Flat_baseline_src_feat_v2.index" % (n_ivf)
)
print("Adding...")
logger.info("Adding...")
batch_size_add = 8192
for i in range(0, big_npy.shape[0], batch_size_add):
index.add(big_npy[i : i + batch_size_add])

View File

@@ -2,6 +2,8 @@
格式直接cid为自带的index位aid放不下了通过字典来查反正就5w个
"""
import os
import logging
logger = logging.getLogger(__name__)
import faiss
import numpy as np
@@ -13,19 +15,19 @@ for name in sorted(list(os.listdir(inp_root))):
phone = np.load("%s/%s" % (inp_root, name))
npys.append(phone)
big_npy = np.concatenate(npys, 0)
print(big_npy.shape) # (6196072, 192)#fp32#4.43G
logger.debug(big_npy.shape) # (6196072, 192)#fp32#4.43G
np.save("infer/big_src_feature_mi.npy", big_npy)
##################train+add
# big_npy=np.load("/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/inference_f0/big_src_feature_mi.npy")
print(big_npy.shape)
logger.debug(big_npy.shape)
index = faiss.index_factory(256, "IVF512,Flat") # mi
print("Training...")
logger.info("Training...")
index_ivf = faiss.extract_index_ivf(index) #
index_ivf.nprobe = 9
index.train(big_npy)
faiss.write_index(index, "infer/trained_IVF512_Flat_mi_baseline_src_feat.index")
print("Adding...")
logger.info("Adding...")
index.add(big_npy)
faiss.write_index(index, "infer/added_IVF512_Flat_mi_baseline_src_feat.index")
"""

View File

@@ -1,6 +1,9 @@
import os
import sys
import traceback
import logging
logger = logging.getLogger(__name__)
from time import time as ttime
import fairseq
@@ -67,7 +70,7 @@ class RVC:
if index_rate != 0:
self.index = faiss.read_index(index_path)
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
print("Index search enabled")
logger.info("Index search enabled")
self.index_path = index_path
self.index_rate = index_rate
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
@@ -102,7 +105,7 @@ class RVC:
else:
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
del self.net_g.enc_q
print(self.net_g.load_state_dict(cpt["weight"], strict=False))
logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False))
self.net_g.eval().to(device)
# print(2333333333,device,config.device,self.device)#net_g是devicehubert是config.device
if config.is_half:
@@ -111,7 +114,7 @@ class RVC:
self.net_g = self.net_g.float()
self.is_half = config.is_half
except:
print(traceback.format_exc())
# Logger.warn is a deprecated alias; Logger.warning is the supported method.
logger.warning(traceback.format_exc())
def change_key(self, new_key):
self.f0_up_key = new_key
@@ -120,7 +123,7 @@ class RVC:
if new_index_rate != 0 and self.index_rate == 0:
self.index = faiss.read_index(self.index_path)
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
print("Index search enabled")
logger.info("Index search enabled")
self.index_rate = new_index_rate
def get_f0_post(self, f0):
@@ -237,7 +240,7 @@ class RVC:
if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE
print("Loading rmvpe model")
logger.info("Loading rmvpe model")
self.model_rmvpe = RMVPE(
# "rmvpe.pt", is_half=self.is_half if self.device.type!="privateuseone" else False, device=self.device if self.device.type!="privateuseone"else "cpu"####dml时强制对rmvpe用cpu跑
# "rmvpe.pt", is_half=False, device=self.device####dml配置
@@ -295,10 +298,10 @@ class RVC:
+ (1 - self.index_rate) * feats[0][-leng_replace_head:]
)
else:
print("Index search FAILED or disabled")
# Logger.warn is a deprecated alias; Logger.warning is the supported method.
logger.warning("Index search FAILED or disabled")
except:
traceback.print_exc()
print("Index search FAILED")
# Logger.warn is a deprecated alias; Logger.warning is the supported method.
logger.warning("Index search FAILED")
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
t3 = ttime()
if self.if_f0 == 1:
@@ -338,5 +341,5 @@ class RVC:
.float()
)
t5 = ttime()
print("Spent time: fea =", t2 - t1, ", index =", t3 - t2, ", f0 =", t4 - t3, ", model =", t5 - t4)
# Logger methods take a format string plus lazy %-args, not print-style
# varargs: extra positional arguments with no placeholders make record
# formatting fail, so the timing line would never be emitted.
# The format reproduces the text the replaced print() call produced.
logger.info(
    "Spent time: fea = %s , index = %s , f0 = %s , model = %s",
    t2 - t1,
    t3 - t2,
    t4 - t3,
    t5 - t4,
)
return infered_audio