mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
[to #41669377] bugfix: fix load missing map_location
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11700872
This commit is contained in:
@@ -101,15 +101,18 @@ class Voice:
|
||||
def __load_am(self):
|
||||
self.__am_model, _, _ = model_builder(self.__am_config, self.__device)
|
||||
self.__am = self.__am_model['KanTtsSAMBERT']
|
||||
state_dict = torch.load(self.__am_ckpts[next(
|
||||
reversed(self.__am_ckpts))])
|
||||
state_dict = torch.load(
|
||||
self.__am_ckpts[next(reversed(self.__am_ckpts))],
|
||||
map_location=self.__device)
|
||||
self.__am.load_state_dict(state_dict['model'], strict=False)
|
||||
self.__am.eval()
|
||||
|
||||
def __load_vocoder(self):
|
||||
self.__voc_model = Generator(
|
||||
**self.__voc_config['Model']['Generator']['params'])
|
||||
states = torch.load(self.__voc_ckpts[next(reversed(self.__voc_ckpts))])
|
||||
states = torch.load(
|
||||
self.__voc_ckpts[next(reversed(self.__voc_ckpts))],
|
||||
map_location=self.__device)
|
||||
self.__voc_model.load_state_dict(states['model']['generator'])
|
||||
if self.__voc_config['Model']['Generator']['params'][
|
||||
'out_channels'] > 1:
|
||||
|
||||
Reference in New Issue
Block a user