diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py index fce037d8..8eefa301 100644 --- a/modelscope/pipelines/cv/action_recognition_pipeline.py +++ b/modelscope/pipelines/cv/action_recognition_pipeline.py @@ -32,7 +32,9 @@ class ActionRecognitionPipeline(Pipeline): config_path = osp.join(self.model, ModelFile.CONFIGURATION) logger.info(f'loading config from {config_path}') self.cfg = Config.from_file(config_path) - self.infer_model = BaseVideoModel(cfg=self.cfg).cuda() + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device) self.infer_model.eval() self.infer_model.load_state_dict(torch.load(model_path)['model_state']) self.label_mapping = self.cfg.label_mapping @@ -40,7 +42,7 @@ class ActionRecognitionPipeline(Pipeline): def preprocess(self, input: Input) -> Dict[str, Any]: if isinstance(input, str): - video_input_data = ReadVideoData(self.cfg, input).cuda() + video_input_data = ReadVideoData(self.cfg, input).to(self.device) else: raise TypeError(f'input should be a str,' f' but got {type(input)}')