mirror of
https://github.com/jasonppy/VoiceCraft.git
init
0
data/__init__.py
Normal file
160
data/giga_preprocessing/encodec_encode.py
Normal file
@@ -0,0 +1,160 @@
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="encode the gigaspeech dataset using the encodec model")
    parser.add_argument("--manifest_root", type=str, default="/home/pyp/audiocraft/egs/gigaspeech", help="this is the dir of the audiocraft manifest")
    parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_flac", help="path to the dir of the flac audio files")
    parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl", help="path to the manifest, phonemes, and encodec codes dirs")
    parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
    parser.add_argument('--n_workers', type=int, default=32, help="number of parallel worker processes")
    parser.add_argument('--batch_size', type=int, default=64, help="batch size for encodec encoding, decrease it if OOM. This is the total batch size summed over all gpus, so increase it if you are using more gpus")
    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
    parser.add_argument('--len_cap', type=float, default=35.0, help='drop audios that are longer than this many seconds')
    return parser.parse_args()

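# Example invocation (illustrative; every path below is a placeholder, not part of the repo):
#   python data/giga_preprocessing/encodec_encode.py \
#       --manifest_root /path/to/audiocraft/egs/gigaspeech \
#       --audio_dir /path/to/gigaspeech_flac \
#       --save_dir /path/to/gigaspeech_phn_enc_manifest/xl \
#       --encodec_model_path /path/to/encodec/checkpoint.th \
#       --batch_size 64 --n_workers 32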
if __name__ == "__main__":
|
||||
import logging
|
||||
formatter = (
|
||||
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
|
||||
)
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import tqdm
|
||||
import time
|
||||
|
||||
args = parse_args()
|
||||
|
||||
manifest_dir = args.manifest_root # this dir is scp-ed
|
||||
audio_dir = args.audio_dir # this is scp-ed flac dir
|
||||
encodec_signature = args.encodec_model_path.split("/")[-2]
|
||||
save_codes_dir = os.path.join(args.save_dir, f"encodec_16khz_{encodec_signature}")
|
||||
os.makedirs(save_codes_dir, exist_ok=True)
|
||||
|
||||
|
||||
# model_sr = 16000
|
||||
# downsample_rate = 320
|
||||
# model_code_sr = 50
|
||||
    def sort_by_audio_len(lens):
        # sort indices by audio length and return them longest-first, so OOM
        # issues surface early and padding within a mini batch stays small
        inds = np.argsort(lens).tolist()
        logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
        logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
        logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
        logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
        return inds[::-1]

    def write_array_to_txt_file(array, filename):
        # one line per codebook, entries space-separated, no trailing newline
        with open(filename, 'w') as f:
            for a in array[:-1]:
                f.write(' '.join(map(str, a))+'\n')
            f.write(' '.join(map(str, array[-1])))

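    # A minimal reader for the files written above (an illustrative helper, not
    # part of the original repo): each of the K codebook rows is a
    # space-separated list of integer code indices.
    def read_array_from_txt_file(filename):
        with open(filename, "r") as f:
            return [[int(n) for n in line.split()] for line in f]
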
    class mydataset(torch.utils.data.Dataset):
        def __init__(self, split):
            super().__init__()
            # self.data = gs[split]
            self.split = split
            self.audio_root = audio_dir
            manifest_fn = os.path.join(manifest_dir, split+".txt")
            with open(manifest_fn, "r") as rf:
                self.data = [l.strip().split("\t") for l in rf.readlines()]
        def __len__(self):
            return len(self.data)
        def __getitem__(self, ind):
            try:
                afn = self.data[ind][0]
                fn = os.path.join(self.audio_root, afn)
                audio, sr = torchaudio.load(fn)
                assert sr == args.model_sr, sr
            except Exception as e:
                logging.info(f"{e}")
                return None, None, None
            assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
            return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.basename(afn).split(".")[0]
        def collate(self, batch):
            # return variable-length lists instead of stacked tensors, so the
            # mega batch below can be re-sorted and padded per mini batch
            lens, audios, segment_ids = [], [], []
            for item in batch:
                if item[0] is not None: # skip samples that failed to load
                    audios.append(item[0])
                    lens.append(item[1])
                    segment_ids.append(item[2])
            return audios, lens, segment_ids

    # load the encodec model
    from audiocraft.solvers import CompressionSolver
    model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
    model = model.cuda()
    model = model.eval()
    model = torch.nn.DataParallel(model)

    # setup dataloaders; each "mega batch" is a large chunk that is later
    # sorted by length and split into GPU-sized mini batches
    mega_batch_size = 2100
    batch_size = args.batch_size
    train_dataset = mydataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
    validation_dataset = mydataset('validation')
    validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
    test_dataset = mydataset('test')
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
    splits = ['validation', 'test', 'train']
    loaders = [validation_loader, test_loader, train_loader]
    # splits = ['validation'] # NOTE this is for debug
    # loaders = [validation_loader]
    for split, loader in zip(splits, loaders):
        skip = 0
        logging.info(f"now processing split {split}...")
        mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
        # mega_n_steps = int(np.ceil(len(gs) / mega_batch_size))
        logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {mega_batch_size} samples")
        # with open(mani_fn, "a") as mani_wf: # resume from where we failed
        for m, mega_batch in enumerate(loader):
            logging.info(f"====================================")
            logging.info(f"====================================")
            logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
            lengths = np.array(mega_batch[1])
            sorted_inds = sort_by_audio_len(lengths)
            for j in range(len(sorted_inds))[::-1]:
                if lengths[sorted_inds[j]] < args.model_sr*0.2 or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are too short (shorter than 0.2s) or too long (longer than args.len_cap seconds, 35s by default)
                    skip += 1
                    del sorted_inds[j]

            n_steps = int(np.ceil(len(sorted_inds) / batch_size))
            for n in tqdm.tqdm(range(n_steps), disable=True):
                inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
                wav_batch = [mega_batch[0][id] for id in inds_used]
                all_lens = [mega_batch[1][id] for id in inds_used]
                segment_id_batch = [mega_batch[2][id] for id in inds_used]
                # print(segment_id_batch)
                padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
                with torch.no_grad():
                    if max(all_lens) > 300000 and len(all_lens) > 1: # NOTE decrease this (300000) if OOM, or chunk it into more than 2 forward passes
                        codes = []
                        inwav = padded_wav.cuda()
                        codes.append(model(inwav[:len(inwav)//2], encode=True)[0].cpu())
                        codes.append(model(inwav[len(inwav)//2:], encode=True)[0].cpu())
                        codes = torch.cat(codes, dim=0)
                    else:
                        encoded_frames = model(padded_wav.cuda(), encode=True) # wav needs to have shape [B, C, T], C is model.channels, which is 1 for the 24kHz encodec model
                        # logging.info(f"encoded_frames: {encoded_frames[0].shape}")
                        codes = encoded_frames[0].cpu()

                for i, length in enumerate(all_lens):
                    save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
                    actual_len = round(length / args.downsample_rate) # 320 is the downsample rate for this model
                    cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist() # strip the padded frames before saving
                    write_array_to_txt_file(cur_code, save_fn)

                    # mani_wf.write(f"0\t{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file
                    # if i == 10:
                    #     raise
            # break
        # logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbidden words")
        logging.info(f"split {split} has {len(loader.dataset)} samples in total, skipped {skip} due to utterances being too long or too short")
    # break
158
data/gigaspeech.py
Normal file
@@ -0,0 +1,158 @@
import os
import torch
import random
import copy
import logging
import shutil

class dataset(torch.utils.data.Dataset):
    def __init__(self, args, split):
        super().__init__()
        self.args = args
        self.split = split
        assert self.split in ['train', 'validation', 'test']
        manifest_fn = os.path.join(self.args.dataset_dir, self.args.manifest_name, self.split+".txt")

        with open(manifest_fn, "r") as rf:
            data = [l.strip().split("\t") for l in rf.readlines()]
        lengths_list = [int(item[-1]) for item in data]
        self.data = []
        self.lengths_list = []
        for d, l in zip(data, lengths_list):
            if l >= self.args.encodec_sr*self.args.audio_min_length:
                if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
                    continue
                self.data.append(d)
                self.lengths_list.append(l)
        logging.info(f"number of data points for {self.split} split: {len(self.lengths_list)}")

        # phoneme vocabulary
        vocab_fn = os.path.join(self.args.dataset_dir, "vocab.txt")
        shutil.copy(vocab_fn, os.path.join(self.args.exp_dir, "vocab.txt"))
        with open(vocab_fn, "r") as f:
            temp = [l.strip().split(" ") for l in f.readlines() if len(l) != 0]
            self.phn2num = {item[1]:int(item[0]) for item in temp}

        self.symbol_set = set(["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"])

    def __len__(self):
        return len(self.lengths_list)

    def _load_phn_enc(self, index):
        item = self.data[index]
        pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
        ef = os.path.join(self.args.dataset_dir, self.args.encodec_folder_name, item[1]+".txt")
        try:
            with open(pf, "r") as p, open(ef, "r") as e:
                phns = [l.strip() for l in p.readlines()]
                assert len(phns) == 1, phns
                x = [self.phn2num[item] for item in phns[0].split(" ") if item not in self.symbol_set] # drop ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"], as they are not in the training set annotation
                encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]

                assert len(encos) == self.args.n_codebooks, ef
                if self.args.special_first:
                    y = [[int(n)+self.args.n_special for n in l] for l in encos]
                else:
                    y = [[int(n) for n in l] for l in encos]
                if self.args.training_stage == 1 and not self.args.valle and not (self.args.musicgen or self.args.valle_orig):
                    y = y[:1] # in training stage 1, only the first codebook is used
        except Exception as e:
            logging.info(f"loading failed for {pf} and {ef}, maybe the files don't exist or are corrupted")
            logging.info(f"error message: {e}")
            return [], [[]]

        return x, y

    def __getitem__(self, index):
        x, y = self._load_phn_enc(index)
        x_len, y_len = len(x), len(y[0])

        if x_len == 0 or y_len == 0:
            return {
                "x": None,
                "x_len": None,
                "y": None,
                "y_len": None,
                "y_mask_interval": None, # index y_mask_interval[1] is the position of the start_of_continue token
                "extra_mask_start": None # this is only used in VE1
            }
        while y_len < self.args.encodec_sr*self.args.audio_min_length:
            assert not self.args.dynamic_batching
            index = random.choice(range(len(self))) # regenerate an index
            x, y = self._load_phn_enc(index)
            x_len, y_len = len(x), len(y[0])
        if self.args.drop_long:
            while x_len > self.args.text_max_length or y_len > self.args.encodec_sr*self.args.audio_max_length:
                index = random.choice(range(len(self))) # regenerate an index
                x, y = self._load_phn_enc(index)
                x_len, y_len = len(x), len(y[0])

        ### padding and cropping below ###
        # adjust the length of encodec codes, pad to max_len or randomly crop
        orig_y_len = copy.copy(y_len)
        max_len = int(self.args.audio_max_length * self.args.encodec_sr)
        if y_len > max_len:
            audio_start = random.choice(range(0, y_len-max_len))
            for i in range(len(y)):
                y[i] = y[i][audio_start:(audio_start+max_len)]
            y_len = max_len
        else:
            audio_start = 0
            if not self.args.dynamic_batching:
                pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len)
                for i in range(len(y)):
                    y[i] = y[i] + pad

        # adjust text
        # if the audio was cropped and the text is longer than max, crop the text based on how the audio was cropped
        if audio_start > 0 and len(x) > self.args.text_max_length: # if both audio and text are longer than max, start the text where the audio started
            x = x[int(len(x)*audio_start/orig_y_len):]
            if len(x) > self.args.text_max_length: # if the text is still longer than max, cut the end
                x = x[:self.args.text_max_length]

        x_len = len(x)
        if x_len > self.args.text_max_length:
            text_start = random.choice(range(0, x_len - self.args.text_max_length))
            x = x[text_start:text_start+self.args.text_max_length]
            x_len = self.args.text_max_length
        elif self.args.pad_x and x_len <= self.args.text_max_length:
            pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len)
            x = x + pad
        ### padding and cropping above ###

        return {
            "x": torch.LongTensor(x),
            "x_len": x_len,
            "y": torch.LongTensor(y),
            "y_len": y_len
        }

    def collate(self, batch):
        out = {key:[] for key in batch[0]}
        for item in batch:
            if item['x'] is None: # deal with load failure
                continue
            for key, val in item.items():
                out[key].append(val)
        res = {}
        if self.args.pad_x:
            res["x"] = torch.stack(out["x"], dim=0)
        else:
            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.text_pad_token)
        res["x_lens"] = torch.LongTensor(out["x_len"])
        if self.args.dynamic_batching:
            if out['y'][0].ndim == 2:
                res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']], padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
                res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
            else:
                assert out['y'][0].ndim == 1, out['y'][0].shape
                res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
        else:
            res['y'] = torch.stack(out['y'], dim=0)
        res["y_lens"] = torch.LongTensor(out["y_len"])
        res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
        res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
        return res
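A minimal usage sketch for the dataset class above (illustrative only: the attribute names come from what this file reads, but every value below is an assumption, not a documented default):

    from argparse import Namespace
    import torch

    args = Namespace(
        dataset_dir="/path/to/gigaspeech_phn_enc_manifest/xl",    # placeholder path
        manifest_name="manifest", phn_folder_name="phn",          # assumed folder names
        encodec_folder_name="encodec_16khz_6f79c6a8",             # matches encodec_encode.py output
        exp_dir="/path/to/exp",                                   # vocab.txt is copied here
        encodec_sr=50,                                            # codes per second, per encodec_encode.py
        audio_min_length=2, audio_max_length=20, drop_long=True,  # seconds; guessed values
        n_codebooks=4, special_first=False, n_special=0,
        training_stage=2, valle=False, musicgen=False, valle_orig=False,
        dynamic_batching=True, text_max_length=400, pad_x=False,
        sep_special_token=False, audio_pad_token=0, text_pad_token=0,
    )
    train_set = dataset(args, "train")
    loader = torch.utils.data.DataLoader(train_set, batch_size=4, collate_fn=train_set.collate)
    batch = next(iter(loader))  # dict with x, x_lens, y, y_lens, and the two padding masks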
149
data/tokenizer.py
Normal file
@@ -0,0 +1,149 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torchaudio
# from lhotse.features import FeatureExtractor
# from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator

class TextTokenizer:
    """Phonemize Text."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        phonemizer = EspeakBackend(
            language,
            punctuation_marks=punctuation_marks,
            preserve_punctuation=preserve_punctuation,
            with_stress=with_stress,
            tie=tie,
            language_switch=language_switch,
            words_mismatch=words_mismatch,
        )

        self.backend = phonemizer
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        fields = []
        for word in phonemized.split(self.separator.word):
            # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            fields.extend(
                [p for p in pp if p != self.separator.phone]
                + [self.separator.word]
            )
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        return fields[:-1]

    def __call__(self, text, strip=True) -> List[List[str]]:
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]

def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    phonemes = tokenizer([text.strip()])
    return phonemes[0]  # k2symbols

def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
    if target_channels == 1:
        wav = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        *shape, _, length = wav.shape
        wav = wav.expand(*shape, target_channels, length)
    elif wav.shape[0] == 1:
        wav = wav.expand(target_channels, -1)
    wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    return wav

class AudioTokenizer:
    """EnCodec audio."""

    def __init__(
        self,
        device: Any = None,
        signature = None
    ) -> None:
        from audiocraft.solvers import CompressionSolver
        model = CompressionSolver.model_from_checkpoint(signature)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")

        self._device = device

        self.codec = model.to(device)

    @property
    def device(self):
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        codes = self.codec.encode(wav.to(self.device))
        return [(codes[0], None)]

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        frames = frames[0][0]  # [1,4,T]
        return self.codec.decode(frames)

def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset=-1, num_frames=-1):
    # Load and pre-process the audio waveform
    if offset != -1 and num_frames != -1:
        wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
    else:
        wav, sr = torchaudio.load(audio_path)
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)

    # Extract discrete codes from EnCodec
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames
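A minimal usage sketch for the tokenizers above (illustrative; the checkpoint and audio paths are placeholders, and the exact phoneme symbols depend on the installed espeak backend):

    # text -> phoneme symbol list
    text_tokenizer = TextTokenizer()
    phns = tokenize_text(text_tokenizer, "The quick brown fox.")  # e.g. ['ð', 'ə', '_', 'k', ...]

    # audio -> discrete encodec codes
    audio_tokenizer = AudioTokenizer(signature="/path/to/encodec/checkpoint.th")  # placeholder
    frames = tokenize_audio(audio_tokenizer, "/path/to/utterance.flac")
    codes = frames[0][0]  # tensor of shape [1, K, T]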