diff --git a/modelscope/models/audio/tts/voice.py b/modelscope/models/audio/tts/voice.py
index 765f6d83..b7b91a9e 100644
--- a/modelscope/models/audio/tts/voice.py
+++ b/modelscope/models/audio/tts/voice.py
@@ -101,15 +101,18 @@ class Voice:
     def __load_am(self):
         self.__am_model, _, _ = model_builder(self.__am_config, self.__device)
         self.__am = self.__am_model['KanTtsSAMBERT']
-        state_dict = torch.load(self.__am_ckpts[next(
-            reversed(self.__am_ckpts))])
+        state_dict = torch.load(
+            self.__am_ckpts[next(reversed(self.__am_ckpts))],
+            map_location=self.__device)
         self.__am.load_state_dict(state_dict['model'], strict=False)
         self.__am.eval()
 
     def __load_vocoder(self):
         self.__voc_model = Generator(
             **self.__voc_config['Model']['Generator']['params'])
-        states = torch.load(self.__voc_ckpts[next(reversed(self.__voc_ckpts))])
+        states = torch.load(
+            self.__voc_ckpts[next(reversed(self.__voc_ckpts))],
+            map_location=self.__device)
         self.__voc_model.load_state_dict(states['model']['generator'])
         if self.__voc_config['Model']['Generator']['params'][
                 'out_channels'] > 1: