This commit is contained in:
jason-on-salt-a40
2024-03-21 11:02:20 -07:00
commit 6760f29bd0
32 changed files with 9321 additions and 0 deletions

0
data/__init__.py Normal file
View File

View File

@@ -0,0 +1,160 @@
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model")
parser.add_argument("--manifest_root", type=str, default="/home/pyp/audiocraft/egs/gigaspeech", help="this the dir of the audiocraft manifest!")
parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_flac", help="Path dirs of the flac audio files")
parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl", help="path to the manifest, phonemes, and encodec codes dirs")
parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
parser.add_argument('--n_workers', type=int, default=32, help="Number of parallel worker processes")
parser.add_argument('--batch_size', type=int, default=64, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number')
return parser.parse_args()
if __name__ == "__main__":
import logging
formatter = (
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
import os
import numpy as np
import torch
import torchaudio
import tqdm
import time
args = parse_args()
manifest_dir = args.manifest_root # this dir is scp-ed
audio_dir = args.audio_dir # this is scp-ed flac dir
encodec_signature = args.encodec_model_path.split("/")[-2]
save_codes_dir = os.path.join(args.save_dir, f"encodec_16khz_{encodec_signature}")
os.makedirs(save_codes_dir, exist_ok=True)
# model_sr = 16000
# downsample_rate = 320
# model_code_sr = 50
def sort_by_audio_len(lens):
inds = np.argsort(lens).tolist()
logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
return inds[::-1]
def write_array_to_txt_file(array, filename):
with open(filename, 'w') as f:
for a in array[:-1]:
f.write(' '.join(map(str, a))+'\n')
f.write(' '.join(map(str, array[-1])))
class mydataset(torch.utils.data.Dataset):
def __init__(self, split):
super().__init__()
# self.data = gs[split]
self.split = split
self.audio_root = audio_dir
manifest_fn = os.path.join(manifest_dir, split+".txt")
with open(manifest_fn, "r") as rf:
self.data = [l.strip().split("\t") for l in rf.readlines()]
def __len__(self):
return len(self.data)
def __getitem__(self, ind):
try:
afn = self.data[ind][0]
fn = os.path.join(self.audio_root, afn)
audio, sr = torchaudio.load(fn)
assert sr == args.model_sr, sr
except Exception as e:
logging.info(f"{e}")
return None, None, None
assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.basename(afn).split(".")[0]
def collate(self, batch):
lens, audios, segment_ids = [], [], []
for item in batch:
if item[0] != None:
audios.append(item[0])
lens.append(item[1])
segment_ids.append(item[2])
return audios, lens, segment_ids
# load the encodec model
from audiocraft.solvers import CompressionSolver
model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
model = model.cuda()
model = model.eval()
model = torch.nn.DataParallel(model)
# setup dataloader
mega_batch_size = 2100
batch_size = args.batch_size
train_dataset = mydataset('train')
train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
validation_dataset = mydataset('validation')
validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
test_dataset = mydataset('test')
test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
splits = ['validation', 'test', 'train']
loaders = [validation_loader, test_loader, train_loader]
# splits = ['validation'] # NOTE this is for debug, for example, see if the
# loaders = [validation_loader]
for split, loader in zip(splits, loaders):
skip = 0
logging.info(f"now processing split {split}...")
mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
# mega_n_steps = int(np.ceil(len(gs) / mega_batch_size))
logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {mega_batch_size} samples")
# with open(mani_fn, "a") as mani_wf: # resume from where we failed
for m, mega_batch in enumerate(loader):
logging.info(f"====================================")
logging.info(f"====================================")
logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
lengths = np.array(mega_batch[1])
sorted_inds = sort_by_audio_len(lengths)
for j in range(len(sorted_inds))[::-1]:
if lengths[sorted_inds[j]] < args.model_sr*0.2 or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s)
skip += 1
del sorted_inds[j]
n_steps = int(np.ceil(len(sorted_inds) / batch_size))
for n in tqdm.tqdm(range(n_steps), disable=True):
inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
wav_batch = [mega_batch[0][id] for id in inds_used]
all_lens = [mega_batch[1][id] for id in inds_used]
segment_id_batch = [mega_batch[2][id] for id in inds_used]
# print(segment_id_batch)
padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
with torch.no_grad():
if max(all_lens) > 300000 and len(all_lens) > 1: # NOTE decrease this (300000) if OOM, or chunk it into more than 2 forward passes
codes = []
inwav = padded_wav.cuda()
codes.append(model(inwav[:len(inwav)//2], encode=True)[0].cpu())
codes.append(model(inwav[len(inwav)//2:], encode=True)[0].cpu())
codes = torch.cat(codes, dim=0)
else:
encoded_frames = model(padded_wav.cuda(), encode=True) # wav needs to have shape [B, C, T], C is model.channels, which is 1 for the 24kHz encodec model
# logging.info(f"encoded_frames: {encoded_frames[0].shape}")
codes = encoded_frames[0].cpu()
for i, length in enumerate(all_lens):
save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this model
cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
write_array_to_txt_file(cur_code, save_fn)
# mani_wf.write(f"0\t{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file
# if i == 10:
# raise
# break
# logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words")
logging.info(f"split {split} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
# break

158
data/gigaspeech.py Normal file
View File

@@ -0,0 +1,158 @@
import os
import torch
import random
import copy
import logging
import shutil
class dataset(torch.utils.data.Dataset):
def __init__(self, args, split):
super().__init__()
self.args = args
self.split = split
assert self.split in ['train', 'validation', 'test']
manifest_fn = os.path.join(self.args.dataset_dir, self.args.manifest_name, self.split+".txt")
with open(manifest_fn, "r") as rf:
data = [l.strip().split("\t") for l in rf.readlines()]
lengths_list = [int(item[-1]) for item in data]
self.data = []
self.lengths_list = []
for d, l in zip(data, lengths_list):
if l >= self.args.encodec_sr*self.args.audio_min_length:
if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
continue
self.data.append(d)
self.lengths_list.append(l)
logging.info(f"number of data points for {self.split} split: {len(self.lengths_list)}")
# phoneme vocabulary
vocab_fn = os.path.join(self.args.dataset_dir,"vocab.txt")
shutil.copy(vocab_fn, os.path.join(self.args.exp_dir, "vocab.txt"))
with open(vocab_fn, "r") as f:
temp = [l.strip().split(" ") for l in f.readlines() if len(l) != 0]
self.phn2num = {item[1]:int(item[0]) for item in temp}
self.symbol_set = set(["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"])
def __len__(self):
return len(self.lengths_list)
def _load_phn_enc(self, index):
item = self.data[index]
pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
ef = os.path.join(self.args.dataset_dir, self.args.encodec_folder_name, item[1]+".txt")
try:
with open(pf, "r") as p, open(ef, "r") as e:
phns = [l.strip() for l in p.readlines()]
assert len(phns) == 1, phns
x = [self.phn2num[item] for item in phns[0].split(" ") if item not in self.symbol_set] # drop ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"], as they are not in training set annotation
encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]
assert len(encos) == self.args.n_codebooks, ef
if self.args.special_first:
y = [[int(n)+self.args.n_special for n in l] for l in encos]
else:
y = [[int(n) for n in l] for l in encos]
if self.args.training_stage == 1 and not self.args.valle and not (self.args.musicgen or self.args.valle_orig):
y = y[:1]
except Exception as e:
logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
logging.info(f"error message: {e}")
return [], [[]]
return x, y
def __getitem__(self, index):
x, y = self._load_phn_enc(index)
x_len, y_len = len(x), len(y[0])
if x_len == 0 or y_len == 0:
return {
"x": None,
"x_len": None,
"y": None,
"y_len": None,
"y_mask_interval": None, # index y_mask_interval[1] is the position of start_of_continue token
"extra_mask_start": None # this is only used in VE1
}
while y_len < self.args.encodec_sr*self.args.audio_min_length:
assert not self.args.dynamic_batching
index = random.choice(range(len(self))) # regenerate an index
x, y = self._load_phn_enc(index)
x_len, y_len = len(x), len(y[0])
if self.args.drop_long:
while x_len > self.args.text_max_length or y_len > self.args.encodec_sr*self.args.audio_max_length:
index = random.choice(range(len(self))) # regenerate an index
x, y = self._load_phn_enc(index)
x_len, y_len = len(x), len(y[0])
### padding and cropping below ###
### padding and cropping below ###
# adjust the length of encodec codes, pad to max_len or randomly crop
orig_y_len = copy.copy(y_len)
max_len = int(self.args.audio_max_length * self.args.encodec_sr)
if y_len > max_len:
audio_start = random.choice(range(0, y_len-max_len))
for i in range(len(y)):
y[i] = y[i][audio_start:(audio_start+max_len)]
y_len = max_len
else:
audio_start = 0
if not self.args.dynamic_batching:
pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len)
for i in range(len(y)):
y[i] = y[i] + pad
# adjust text
# if audio is cropped, and text is longer than max, crop max based on how audio is cropped
if audio_start > 0 and len(x) > self.args.text_max_length: # if audio is longer than max and text is long than max, start text the way audio started
x = x[int(len(x)*audio_start/orig_y_len):]
if len(x) > self.args.text_max_length: # if text is still longer than max, cut the end
x = x[:self.args.text_max_length]
x_len = len(x)
if x_len > self.args.text_max_length:
text_start = random.choice(range(0, x_len - self.args.text_max_length))
x = x[text_start:text_start+self.args.text_max_length]
x_len = self.args.text_max_length
elif self.args.pad_x and x_len <= self.args.text_max_length:
pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len)
x = x + pad
### padding and cropping above ###
### padding and cropping above ###
return {
"x": torch.LongTensor(x),
"x_len": x_len,
"y": torch.LongTensor(y),
"y_len": y_len
}
def collate(self, batch):
out = {key:[] for key in batch[0]}
for item in batch:
if item['x'] == None: # deal with load failure
continue
for key, val in item.items():
out[key].append(val)
res = {}
if self.args.pad_x:
res["x"] = torch.stack(out["x"], dim=0)
else:
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.text_pad_token)
res["x_lens"] = torch.LongTensor(out["x_len"])
if self.args.dynamic_batching:
if out['y'][0].ndim==2:
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
else:
assert out['y'][0].ndim==1, out['y'][0].shape
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
else:
res['y'] = torch.stack(out['y'], dim=0)
res["y_lens"] = torch.LongTensor(out["y_len"])
res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
return res

149
data/tokenizer.py Normal file
View File

@@ -0,0 +1,149 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union
import numpy as np
import torch
import torchaudio
# from lhotse.features import FeatureExtractor
# from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
class TextTokenizer:
"""Phonemize Text."""
def __init__(
self,
language="en-us",
backend="espeak",
separator=Separator(word="_", syllable="-", phone="|"),
preserve_punctuation=True,
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
with_stress: bool = False,
tie: Union[bool, str] = False,
language_switch: LanguageSwitch = "keep-flags",
words_mismatch: WordMismatch = "ignore",
) -> None:
phonemizer = EspeakBackend(
language,
punctuation_marks=punctuation_marks,
preserve_punctuation=preserve_punctuation,
with_stress=with_stress,
tie=tie,
language_switch=language_switch,
words_mismatch=words_mismatch,
)
self.backend = phonemizer
self.separator = separator
def to_list(self, phonemized: str) -> List[str]:
fields = []
for word in phonemized.split(self.separator.word):
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
fields.extend(
[p for p in pp if p != self.separator.phone]
+ [self.separator.word]
)
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
self.separator.phone
)
return fields[:-1]
def __call__(self, text, strip=True) -> List[List[str]]:
if isinstance(text, str):
text = [text]
phonemized = self.backend.phonemize(
text, separator=self.separator, strip=strip, njobs=1
)
return [self.to_list(p) for p in phonemized]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
phonemes = tokenizer([text.strip()])
return phonemes[0] # k2symbols
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
if target_channels == 1:
wav = wav.mean(0, keepdim=True)
elif target_channels == 2:
*shape, _, length = wav.shape
wav = wav.expand(*shape, target_channels, length)
elif wav.shape[0] == 1:
wav = wav.expand(target_channels, -1)
wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
return wav
class AudioTokenizer:
"""EnCodec audio."""
def __init__(
self,
device: Any = None,
signature = None
) -> None:
from audiocraft.solvers import CompressionSolver
model = CompressionSolver.model_from_checkpoint(signature)
self.sample_rate = model.sample_rate
self.channels = model.channels
if not device:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda:0")
self._device = device
self.codec = model.to(device)
@property
def device(self):
return self._device
def encode(self, wav: torch.Tensor) -> torch.Tensor:
codes = self.codec.encode(wav.to(self.device))
return [(codes[0], None)]
def decode(self, frames: torch.Tensor) -> torch.Tensor:
frames = frames[0][0] # [1,4,T]
return self.codec.decode(frames)
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1):
# Load and pre-process the audio waveform
if offset != -1 and num_frames!=-1:
wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
else:
wav, sr = torchaudio.load(audio_path)
wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
wav = wav.unsqueeze(0)
# Extract discrete codes from EnCodec
with torch.no_grad():
encoded_frames = tokenizer.encode(wav)
return encoded_frames