From 686e9a4d4e7857cb39c65dacf3fb2fc9eae3b422 Mon Sep 17 00:00:00 2001 From: "jiaqi.sjq" Date: Fri, 17 Feb 2023 14:16:45 +0800 Subject: [PATCH] [to #41669377] bugfix: fix load missing map_location Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11700872 --- modelscope/models/audio/tts/voice.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modelscope/models/audio/tts/voice.py b/modelscope/models/audio/tts/voice.py index 765f6d83..b7b91a9e 100644 --- a/modelscope/models/audio/tts/voice.py +++ b/modelscope/models/audio/tts/voice.py @@ -101,15 +101,18 @@ class Voice: def __load_am(self): self.__am_model, _, _ = model_builder(self.__am_config, self.__device) self.__am = self.__am_model['KanTtsSAMBERT'] - state_dict = torch.load(self.__am_ckpts[next( - reversed(self.__am_ckpts))]) + state_dict = torch.load( + self.__am_ckpts[next(reversed(self.__am_ckpts))], + map_location=self.__device) self.__am.load_state_dict(state_dict['model'], strict=False) self.__am.eval() def __load_vocoder(self): self.__voc_model = Generator( **self.__voc_config['Model']['Generator']['params']) - states = torch.load(self.__voc_ckpts[next(reversed(self.__voc_ckpts))]) + states = torch.load( + self.__voc_ckpts[next(reversed(self.__voc_ckpts))], + map_location=self.__device) self.__voc_model.load_state_dict(states['model']['generator']) if self.__voc_config['Model']['Generator']['params'][ 'out_channels'] > 1: