[to #41669377] bugfix: fix load missing map_location

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11700872
This commit is contained in:
jiaqi.sjq
2023-02-17 14:16:45 +08:00
committed by yingda.chen
parent 9d1a9301a6
commit 686e9a4d4e

View File

@@ -101,15 +101,18 @@ class Voice:
def __load_am(self):
self.__am_model, _, _ = model_builder(self.__am_config, self.__device)
self.__am = self.__am_model['KanTtsSAMBERT']
state_dict = torch.load(self.__am_ckpts[next(
reversed(self.__am_ckpts))])
state_dict = torch.load(
self.__am_ckpts[next(reversed(self.__am_ckpts))],
map_location=self.__device)
self.__am.load_state_dict(state_dict['model'], strict=False)
self.__am.eval()
def __load_vocoder(self):
self.__voc_model = Generator(
**self.__voc_config['Model']['Generator']['params'])
states = torch.load(self.__voc_ckpts[next(reversed(self.__voc_ckpts))])
states = torch.load(
self.__voc_ckpts[next(reversed(self.__voc_ckpts))],
map_location=self.__device)
self.__voc_model.load_state_dict(states['model']['generator'])
if self.__voc_config['Model']['Generator']['params'][
'out_channels'] > 1: