From 6e4d02bfbf14aebcd22772a0a2a30791f99c4531 Mon Sep 17 00:00:00 2001 From: Mylo <36931363+gitmylo@users.noreply.github.com> Date: Fri, 26 May 2023 22:38:39 +0200 Subject: [PATCH] Update customtokenizer.py to latest version. --- hubert/customtokenizer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/hubert/customtokenizer.py b/hubert/customtokenizer.py index 7f807d3..c1a4a51 100644 --- a/hubert/customtokenizer.py +++ b/hubert/customtokenizer.py @@ -1,4 +1,8 @@ -# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer +""" +Custom tokenizer model. +Author: https://www.github.com/gitmylo/ +License: MIT +""" import json import os.path @@ -91,7 +95,7 @@ class CustomTokenizer(nn.Module): optimizer.step() def save(self, path): - info_path = os.path.basename(path) + '/.info' + info_path = '.'.join(os.path.basename(path).split('.')[:-1]) + '/.info' torch.save(self.state_dict(), path) data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version) with ZipFile(path, 'a') as model_zip: @@ -112,7 +116,9 @@ class CustomTokenizer(nn.Module): model = CustomTokenizer() else: model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version) - model.load_state_dict(torch.load(path, map_location)) + model.load_state_dict(torch.load(path)) + if map_location: + model = model.to(map_location) return model