diff --git a/modelscope/models/multi_modal/image_captioning_model.py b/modelscope/models/multi_modal/image_captioning_model.py
index 79ab2b5f..0154ac29 100644
--- a/modelscope/models/multi_modal/image_captioning_model.py
+++ b/modelscope/models/multi_modal/image_captioning_model.py
@@ -1,6 +1,7 @@
 import os.path as osp
 from typing import Any, Dict
 
+import torch.cuda
 from PIL import Image
 
 from modelscope.metainfo import Models
@@ -26,9 +27,13 @@ class OfaForImageCaptioning(Model):
         self.eval_caption = eval_caption
         tasks.register_task('caption', CaptionTask)
 
-        use_cuda = kwargs['use_cuda'] if 'use_cuda' in kwargs else False
-        use_fp16 = kwargs[
-            'use_fp16'] if 'use_fp16' in kwargs and use_cuda else False
+        if torch.cuda.is_available():
+            self._device = torch.device('cuda')
+        else:
+            self._device = torch.device('cpu')
+        self.use_fp16 = kwargs[
+            'use_fp16'] if 'use_fp16' in kwargs and torch.cuda.is_available()\
+            else False
         overrides = {
             'bpe_dir': bpe_dir,
             'eval_cider': False,
@@ -39,13 +44,11 @@ class OfaForImageCaptioning(Model):
         }
         models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
             utils.split_paths(local_model), arg_overrides=overrides)
 
-        # Move models to GPU
         for model in models:
             model.eval()
-            if use_cuda:
-                model.cuda()
-            if use_fp16:
+            model.to(self._device)
+            if self.use_fp16:
                 model.half()
             model.prepare_for_inference_(cfg)
         self.models = models
@@ -68,6 +71,9 @@ class OfaForImageCaptioning(Model):
         self.task = task
 
     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        import fairseq.utils
+        if torch.cuda.is_available():
+            input = fairseq.utils.move_to_cuda(input, device=self._device)
         results, _ = self.eval_caption(self.task, self.generator, self.models,
                                        input)
         return {
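
For reference, the device-handling pattern this change moves to (select the device once from torch.cuda.is_available(), move the model there, cast to fp16 only when a GPU is present, and move each input batch to the same device before inference) can be sketched in isolation. This is a minimal illustration, not modelscope code: ToyModel and batch are hypothetical stand-ins for the loaded OFA models and the forward() input.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for the loaded OFA models."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

    def forward(self, x):
        return self.linear(x)


# Pick the device once, mirroring the __init__ logic in the diff.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# fp16 is only enabled on GPU, as in the diff.
use_fp16 = torch.cuda.is_available()

model = ToyModel().eval().to(device)
if use_fp16:
    model.half()

# Mirror of forward(): move the input batch to the same device (and dtype)
# before running inference.
batch = torch.randn(4, 8)
batch = batch.to(device=device, dtype=torch.float16 if use_fp16 else torch.float32)
with torch.no_grad():
    out = model(batch)
print(out.shape)  # torch.Size([4, 2])

One caveat this sketch makes visible: unlike the removed use_cuda kwarg, the new code decides the device purely from torch.cuda.is_available(), so callers can no longer force CPU inference on a CUDA-capable machine.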