mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
[to #42322933] Add cv-action-recongnition-pipeline run inference with the cpu
行为识别推理同时支持CPU和GPU
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9253438
This commit is contained in:
@@ -32,7 +32,9 @@ class ActionRecognitionPipeline(Pipeline):
|
||||
config_path = osp.join(self.model, ModelFile.CONFIGURATION)
|
||||
logger.info(f'loading config from {config_path}')
|
||||
self.cfg = Config.from_file(config_path)
|
||||
self.infer_model = BaseVideoModel(cfg=self.cfg).cuda()
|
||||
self.device = torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device)
|
||||
self.infer_model.eval()
|
||||
self.infer_model.load_state_dict(torch.load(model_path)['model_state'])
|
||||
self.label_mapping = self.cfg.label_mapping
|
||||
@@ -40,7 +42,7 @@ class ActionRecognitionPipeline(Pipeline):
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
if isinstance(input, str):
|
||||
video_input_data = ReadVideoData(self.cfg, input).cuda()
|
||||
video_input_data = ReadVideoData(self.cfg, input).to(self.device)
|
||||
else:
|
||||
raise TypeError(f'input should be a str,'
|
||||
f' but got {type(input)}')
|
||||
|
||||
Reference in New Issue
Block a user