optimize: realtime inference (#1693)

* update real-time gui

* update real-time gui

* update real-time gui
Author: yxlllc
Date: 2024-01-16 19:22:55 +08:00
Committed by: GitHub
Parent: 26e2805f0e
Commit: c3e65cdf96
14 changed files with 143 additions and 96 deletions

@@ -91,8 +91,8 @@ class RVC:
         self.pth_path: str = pth_path
         self.index_path = index_path
         self.index_rate = index_rate
-        self.cache_pitch: np.ndarray = np.zeros(1024, dtype="int32")
-        self.cache_pitchf = np.zeros(1024, dtype="float32")
+        self.cache_pitch: torch.Tensor = torch.zeros(1024, device=self.device, dtype=torch.long)
+        self.cache_pitchf = torch.zeros(1024, device=self.device, dtype=torch.float32)

         if last_rvc is None:
             models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
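The rolling pitch caches are now allocated once as torch tensors on the inference device instead of NumPy arrays on the host, so the per-block cache update in infer() can stay on the GPU. A minimal sketch of the idea, not the commit's exact code:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # coarse pitch ids feed an embedding lookup, hence the integer (long) dtype
    cache_pitch = torch.zeros(1024, device=device, dtype=torch.long)
    # continuous pitch in Hz for the vocoder path
    cache_pitchf = torch.zeros(1024, device=device, dtype=torch.float32)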
@@ -199,15 +199,17 @@ class RVC:
         self.index_rate = new_index_rate

     def get_f0_post(self, f0):
-        f0bak = f0.copy()
-        f0_mel = 1127 * np.log(1 + f0 / 700)
+        if not torch.is_tensor(f0):
+            f0 = torch.from_numpy(f0)
+        f0 = f0.float().to(self.device).squeeze()
+        f0_mel = 1127 * torch.log(1 + f0 / 700)
         f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (
             self.f0_mel_max - self.f0_mel_min
         ) + 1
         f0_mel[f0_mel <= 1] = 1
         f0_mel[f0_mel > 255] = 255
-        f0_coarse = np.rint(f0_mel).astype(np.int32)
-        return f0_coarse, f0bak
+        f0_coarse = torch.round(f0_mel).long()
+        return f0_coarse, f0

     def get_f0(self, x, f0_up_key, n_cpu, method="harvest"):
         n_cpu = int(n_cpu)
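get_f0_post now accepts either a NumPy array or a torch tensor, moves it to the device, and runs the mel-scale quantization entirely with torch ops, returning long coarse-pitch ids for the embedding lookup together with the raw f0. A standalone sketch of that quantization, assuming the usual RVC bounds f0_min = 50 Hz and f0_max = 1100 Hz (the bounds are an assumption, not read from this diff):

    import math
    import torch

    F0_MIN, F0_MAX = 50.0, 1100.0  # assumed pitch bounds in Hz
    F0_MEL_MIN = 1127 * math.log(1 + F0_MIN / 700)
    F0_MEL_MAX = 1127 * math.log(1 + F0_MAX / 700)

    def quantize_f0(f0: torch.Tensor):
        f0_mel = 1127 * torch.log(1 + f0 / 700)  # Hz -> mel
        voiced = f0_mel > 0
        # map voiced frames onto 1..255; unvoiced frames (0 Hz) stay at bin 1
        f0_mel[voiced] = (f0_mel[voiced] - F0_MEL_MIN) * 254 / (F0_MEL_MAX - F0_MEL_MIN) + 1
        f0_mel = f0_mel.clamp(1, 255)
        return torch.round(f0_mel).long(), f0  # coarse ids, raw f0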
@@ -299,14 +301,12 @@ class RVC:
         pd = torchcrepe.filter.median(pd, 3)
         f0 = torchcrepe.filter.mean(f0, 3)
         f0[pd < 0.1] = 0
-        f0 = f0[0].cpu().numpy()
         f0 *= pow(2, f0_up_key / 12)
         return self.get_f0_post(f0)

     def get_f0_rmvpe(self, x, f0_up_key):
         if hasattr(self, "model_rmvpe") == False:
             from infer.lib.rmvpe import RMVPE

             printt("Loading rmvpe model")
             self.model_rmvpe = RMVPE(
                 "assets/rmvpe/rmvpe.pt",
@@ -335,7 +335,6 @@ class RVC:
             threshold=0.006,
         )
         f0 *= pow(2, f0_up_key / 12)
-        f0 = f0.squeeze().cpu().numpy()
         return self.get_f0_post(f0)

     def infer(
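Because get_f0_post is now tensor-aware, the crepe and fcpe extractors drop their .cpu().numpy() conversions and hand the pitch curve over while it is still on the device, avoiding a host sync on every block. The pattern in a hedged standalone form (the function name is illustrative):

    import torch

    def shift_pitch(f0: torch.Tensor, f0_up_key: int) -> torch.Tensor:
        # previously: f0 = f0[0].cpu().numpy()  # forced a device -> host sync each block
        f0 = f0[0]                          # keep the (1, T) extractor output on the device
        return f0 * 2 ** (f0_up_key / 12)   # transpose by f0_up_key semitones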
@@ -383,6 +382,7 @@ class RVC:
                 traceback.print_exc()
                 printt("Index search FAILED")
         t3 = ttime()
+        p_len = input_wav.shape[0] // 160
         if self.if_f0 == 1:
             f0_extractor_frame = block_frame_16k + 800
             if f0method == "rmvpe":
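p_len, the number of feature frames in the current block, is now computed once before pitch extraction so the pitch-cache slice and the later feature trimming can both reuse it. The model consumes 16 kHz audio with a hop of 160 samples, so for example:

    input_samples = 16000           # one second of 16 kHz audio
    p_len = input_samples // 160    # 100 frames of 10 ms each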
@@ -390,25 +390,14 @@ class RVC:
             pitch, pitchf = self.get_f0(
                 input_wav[-f0_extractor_frame:], self.f0_up_key, self.n_cpu, f0method
             )
-            start_frame = block_frame_16k // 160
-            end_frame = len(self.cache_pitch) - (pitch.shape[0] - 4) + start_frame
-            self.cache_pitch[:] = np.append(
-                self.cache_pitch[start_frame:end_frame], pitch[3:-1]
-            )
-            self.cache_pitchf[:] = np.append(
-                self.cache_pitchf[start_frame:end_frame], pitchf[3:-1]
-            )
+            shift = block_frame_16k // 160
+            self.cache_pitch[:-shift] = self.cache_pitch[shift:].clone()
+            self.cache_pitchf[:-shift] = self.cache_pitchf[shift:].clone()
+            self.cache_pitch[4 - pitch.shape[0] :] = pitch[3:-1]
+            self.cache_pitchf[4 - pitch.shape[0] :] = pitchf[3:-1]
+            cache_pitch = self.cache_pitch[None, -p_len:]
+            cache_pitchf = self.cache_pitchf[None, -p_len:]
         t4 = ttime()
-        p_len = input_wav.shape[0] // 160
-        if self.if_f0 == 1:
-            cache_pitch = (
-                torch.LongTensor(self.cache_pitch[-p_len:]).to(self.device).unsqueeze(0)
-            )
-            cache_pitchf = (
-                torch.FloatTensor(self.cache_pitchf[-p_len:])
-                .to(self.device)
-                .unsqueeze(0)
-            )
         feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
         feats = feats[:, :p_len, :]
         p_len = torch.LongTensor([p_len]).to(self.device)
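With p_len known up front, the per-block pitch-cache update becomes a pure tensor operation: slide the cache left by one block (shift frames), write the freshest frames at the tail, and take the last p_len entries as a (1, p_len) view. This replaces the old np.append reallocations and the per-block torch.LongTensor / torch.FloatTensor construction plus .to(device) copies. A self-contained sketch of this ring-buffer-style update, with made-up sizes:

    import torch

    def roll_in(cache: torch.Tensor, fresh: torch.Tensor, shift: int) -> None:
        # .clone() is needed because the source and destination slices overlap
        cache[:-shift] = cache[shift:].clone()  # drop the oldest shift frames
        cache[-fresh.shape[0]:] = fresh         # append the newest frames in place

    device = "cpu"                            # would be the inference device in practice
    cache_pitchf = torch.zeros(1024, device=device)
    block_frame_16k = 4000                    # hypothetical block size in 16 kHz samples
    shift = block_frame_16k // 160            # frames produced per block
    fresh = torch.rand(shift, device=device)  # newly extracted pitch for this block
    roll_in(cache_pitchf, fresh, shift)

    p_len = 3 * shift                                  # hypothetical conditioning length
    cache_pitchf_block = cache_pitchf[None, -p_len:]   # (1, p_len) view, no host copy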