Update customtokenizer.py to latest version.

This commit is contained in:
Mylo
2023-05-26 22:38:39 +02:00
committed by GitHub
parent f95ba00756
commit 6e4d02bfbf

View File

@@ -1,4 +1,8 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
"""
Custom tokenizer model.
Author: https://www.github.com/gitmylo/
License: MIT
"""
import json
import os.path
@@ -91,7 +95,7 @@ class CustomTokenizer(nn.Module):
optimizer.step()
def save(self, path):
info_path = os.path.basename(path) + '/.info'
info_path = '.'.join(os.path.basename(path).split('.')[:-1]) + '/.info'
torch.save(self.state_dict(), path)
data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version)
with ZipFile(path, 'a') as model_zip:
@@ -112,7 +116,9 @@ class CustomTokenizer(nn.Module):
model = CustomTokenizer()
else:
model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version)
model.load_state_dict(torch.load(path, map_location))
model.load_state_dict(torch.load(path))
if map_location:
model = model.to(map_location)
return model