diff --git a/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py index 0ec66069..76f30580 100644 --- a/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py +++ b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py @@ -58,7 +58,7 @@ class TextToVideoSynthesis(Model): `True`. """ super().__init__(model_dir=model_dir, *args, **kwargs) - self.device = torch.device('cuda') if torch.cuda.is_available() \ + self.device = torch.device(kwargs.get('device', 'cuda')) if torch.cuda.is_available() \ else torch.device('cpu') self.config = Config.from_file( osp.join(model_dir, ModelFile.CONFIGURATION))