mirror of
https://github.com/liuhaozhe6788/voice-cloning-collab.git
synced 2025-12-21 05:59:46 +01:00
new commits
This commit is contained in:
76
encoder/data_objects/speaker_verification_dataset.py
Normal file
76
encoder/data_objects/speaker_verification_dataset.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from encoder.data_objects.random_cycler import RandomCycler
|
||||
from encoder.data_objects.speaker_batch import SpeakerBatch
|
||||
from encoder.data_objects.utterance_batch import UtteranceBatch
|
||||
from encoder.data_objects.speaker import Speaker
|
||||
from encoder.params_data import partials_n_frames
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from pathlib import Path
|
||||
from os import listdir
|
||||
from os.path import isfile
|
||||
import numpy as np
|
||||
|
||||
# TODO: improve with a pool of speakers for data efficiency
|
||||
|
||||
class Train_Dataset(Dataset):
|
||||
def __init__(self, datasets_root: Path):
|
||||
self.root = datasets_root
|
||||
speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
|
||||
if len(speaker_dirs) == 0:
|
||||
raise Exception("No speakers found. Make sure you are pointing to the directory "
|
||||
"containing all preprocessed speaker directories.")
|
||||
self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
|
||||
self.speaker_cycler = RandomCycler(self.speakers)
|
||||
|
||||
def __len__(self):
|
||||
return int(1e8)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return next(self.speaker_cycler)
|
||||
|
||||
def get_logs(self):
|
||||
log_string = ""
|
||||
for log_fpath in self.root.glob("*.txt"):
|
||||
with log_fpath.open("r") as log_file:
|
||||
log_string += "".join(log_file.readlines())
|
||||
return log_string
|
||||
|
||||
|
||||
class Dev_Dataset(Dataset):
|
||||
def __init__(self, datasets_root: Path):
|
||||
self.root = datasets_root
|
||||
speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
|
||||
if len(speaker_dirs) == 0:
|
||||
raise Exception("No speakers found. Make sure you are pointing to the directory "
|
||||
"containing all preprocessed speaker directories.")
|
||||
self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
|
||||
self.speaker_cycler = RandomCycler(self.speakers)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.speakers)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return next(self.speaker_cycler)
|
||||
|
||||
|
||||
class DataLoader(DataLoader):
|
||||
def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, shuffle, sampler=None,
|
||||
batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
|
||||
worker_init_fn=None):
|
||||
self.utterances_per_speaker = utterances_per_speaker
|
||||
|
||||
super().__init__(
|
||||
dataset=dataset,
|
||||
batch_size=speakers_per_batch,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=self.collate,
|
||||
pin_memory=pin_memory,
|
||||
drop_last=False,
|
||||
timeout=timeout,
|
||||
worker_init_fn=worker_init_fn
|
||||
)
|
||||
|
||||
def collate(self, speakers):
|
||||
return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
|
||||
Reference in New Issue
Block a user