From 82e222ed14bc6703cf378ce024e41edec77ddce4 Mon Sep 17 00:00:00 2001 From: "yongfei.zyf" Date: Mon, 4 Jul 2022 14:03:24 +0800 Subject: [PATCH] [to #42322933] Add cv-action-recongnition-pipeline run inference with the cpu MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 行为识别推理同时支持CPU和GPU Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9253438 --- modelscope/pipelines/cv/action_recognition_pipeline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py index fce037d8..8eefa301 100644 --- a/modelscope/pipelines/cv/action_recognition_pipeline.py +++ b/modelscope/pipelines/cv/action_recognition_pipeline.py @@ -32,7 +32,9 @@ class ActionRecognitionPipeline(Pipeline): config_path = osp.join(self.model, ModelFile.CONFIGURATION) logger.info(f'loading config from {config_path}') self.cfg = Config.from_file(config_path) - self.infer_model = BaseVideoModel(cfg=self.cfg).cuda() + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device) self.infer_model.eval() self.infer_model.load_state_dict(torch.load(model_path)['model_state']) self.label_mapping = self.cfg.label_mapping @@ -40,7 +42,7 @@ class ActionRecognitionPipeline(Pipeline): def preprocess(self, input: Input) -> Dict[str, Any]: if isinstance(input, str): - video_input_data = ReadVideoData(self.cfg, input).cuda() + video_input_data = ReadVideoData(self.cfg, input).to(self.device) else: raise TypeError(f'input should be a str,' f' but got {type(input)}')