mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-15 19:27:57 +01:00
Update customtokenizer.py to latest version.
This commit is contained in:
@@ -1,4 +1,8 @@
|
|||||||
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
|
"""
|
||||||
|
Custom tokenizer model.
|
||||||
|
Author: https://www.github.com/gitmylo/
|
||||||
|
License: MIT
|
||||||
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os.path
|
import os.path
|
||||||
@@ -91,7 +95,7 @@ class CustomTokenizer(nn.Module):
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
def save(self, path):
|
def save(self, path):
|
||||||
info_path = os.path.basename(path) + '/.info'
|
info_path = '.'.join(os.path.basename(path).split('.')[:-1]) + '/.info'
|
||||||
torch.save(self.state_dict(), path)
|
torch.save(self.state_dict(), path)
|
||||||
data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version)
|
data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version)
|
||||||
with ZipFile(path, 'a') as model_zip:
|
with ZipFile(path, 'a') as model_zip:
|
||||||
@@ -112,7 +116,9 @@ class CustomTokenizer(nn.Module):
|
|||||||
model = CustomTokenizer()
|
model = CustomTokenizer()
|
||||||
else:
|
else:
|
||||||
model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version)
|
model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version)
|
||||||
model.load_state_dict(torch.load(path, map_location))
|
model.load_state_dict(torch.load(path))
|
||||||
|
if map_location:
|
||||||
|
model = model.to(map_location)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user