fix: index_root searching

close #1147
This commit is contained in:
源文雨
2023-09-01 14:11:55 +08:00
parent d634c2727e
commit 8ffdcb0128
21 changed files with 59 additions and 76 deletions

View File

@@ -23,7 +23,7 @@ vc = VC(config)
weight_root = os.getenv("weight_root")
weight_uvr5_root = os.getenv("weight_uvr5_root")
index_root = "logs"
index_root = os.getenv("index_root")
names = []
hubert_model = None
for name in os.listdir(weight_root):

View File

@@ -55,7 +55,7 @@ def main(path, root):
torch.manual_seed(114514)
model_a = torch.load(path, map_location="cpu")["weight"]
print("query:\t\t%s\t%s" % (path, model_hash(path)))
print("Query:\t\t%s\t%s" % (path, model_hash(path)))
map_attn_a = {}
map_rand_input = {}
@@ -83,7 +83,7 @@ def main(path, root):
sims.append(sim)
print(
"reference:\t%s\t%s\t%s"
"Reference:\t%s\t%s\t%s"
% (path, model_hash(path), f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%")
)

View File

@@ -34,7 +34,7 @@ from scipy.io import wavfile
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = r"E:\codes\py39\vits_vc_gpu_train\assets\hubert\hubert_base.pt" #
print("load model(s) from {}".format(model_path))
print("Load model(s) from {}".format(model_path))
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
[model_path],
suffix="",

View File

@@ -50,14 +50,14 @@ np.save("tools/infer/big_src_feature_mi.npy", big_npy)
# big_npy=np.load("/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/inference_f0/big_src_feature_mi.npy")
n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
index = faiss.index_factory(768, "IVF%s,Flat" % n_ivf) # mi
print("training")
print("Training...")
index_ivf = faiss.extract_index_ivf(index) #
index_ivf.nprobe = 1
index.train(big_npy)
faiss.write_index(
index, "tools/infer/trained_IVF%s_Flat_baseline_src_feat_v2.index" % (n_ivf)
)
print("adding")
print("Adding...")
batch_size_add = 8192
for i in range(0, big_npy.shape[0], batch_size_add):
index.add(big_npy[i : i + batch_size_add])

View File

@@ -20,12 +20,12 @@ np.save("infer/big_src_feature_mi.npy", big_npy)
# big_npy=np.load("/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/inference_f0/big_src_feature_mi.npy")
print(big_npy.shape)
index = faiss.index_factory(256, "IVF512,Flat") # mi
print("training")
print("Training...")
index_ivf = faiss.extract_index_ivf(index) #
index_ivf.nprobe = 9
index.train(big_npy)
faiss.write_index(index, "infer/trained_IVF512_Flat_mi_baseline_src_feat.index")
print("adding")
print("Adding...")
index.add(big_npy)
faiss.write_index(index, "infer/added_IVF512_Flat_mi_baseline_src_feat.index")
"""

View File

@@ -67,7 +67,7 @@ class RVC:
if index_rate != 0:
self.index = faiss.read_index(index_path)
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
print("index search enabled")
print("Index search enabled")
self.index_path = index_path
self.index_rate = index_rate
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
@@ -120,7 +120,7 @@ class RVC:
if new_index_rate != 0 and self.index_rate == 0:
self.index = faiss.read_index(self.index_path)
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
print("index search enabled")
print("Index search enabled")
self.index_rate = new_index_rate
def get_f0_post(self, f0):
@@ -237,7 +237,7 @@ class RVC:
if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE
print("loading rmvpe model")
print("Loading rmvpe model")
self.model_rmvpe = RMVPE(
# "rmvpe.pt", is_half=self.is_half if self.device.type!="privateuseone" else False, device=self.device if self.device.type!="privateuseone"else "cpu"####dml时强制对rmvpe用cpu跑
# "rmvpe.pt", is_half=False, device=self.device####dml配置
@@ -295,10 +295,10 @@ class RVC:
+ (1 - self.index_rate) * feats[0][-leng_replace_head:]
)
else:
print("index search FAIL or disabled")
print("Index search FAILED or disabled")
except:
traceback.print_exc()
print("index search FAIL")
print("Index search FAILED")
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
t3 = ttime()
if self.if_f0 == 1:
@@ -338,5 +338,5 @@ class RVC:
.float()
)
t5 = ttime()
print("time->fea-index-f0-model:", t2 - t1, t3 - t2, t4 - t3, t5 - t4)
print("Spent time: fea =", t2 - t1, ", index =", t3 - t2, ", f0 =", t4 - t3, ", model =", t5 - t4)
return infered_audio