mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
[to #42322933] use gpu when available
ofa/caption 增加feature, 如果有gpu默认使用gpu
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9228113
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os.path as osp
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch.cuda
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
@@ -26,9 +27,13 @@ class OfaForImageCaptioning(Model):
|
||||
self.eval_caption = eval_caption
|
||||
|
||||
tasks.register_task('caption', CaptionTask)
|
||||
use_cuda = kwargs['use_cuda'] if 'use_cuda' in kwargs else False
|
||||
use_fp16 = kwargs[
|
||||
'use_fp16'] if 'use_fp16' in kwargs and use_cuda else False
|
||||
if torch.cuda.is_available():
|
||||
self._device = torch.device('cuda')
|
||||
else:
|
||||
self._device = torch.device('cpu')
|
||||
self.use_fp16 = kwargs[
|
||||
'use_fp16'] if 'use_fp16' in kwargs and torch.cuda.is_available()\
|
||||
else False
|
||||
overrides = {
|
||||
'bpe_dir': bpe_dir,
|
||||
'eval_cider': False,
|
||||
@@ -39,13 +44,11 @@ class OfaForImageCaptioning(Model):
|
||||
}
|
||||
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
utils.split_paths(local_model), arg_overrides=overrides)
|
||||
|
||||
# Move models to GPU
|
||||
for model in models:
|
||||
model.eval()
|
||||
if use_cuda:
|
||||
model.cuda()
|
||||
if use_fp16:
|
||||
model.to(self._device)
|
||||
if self.use_fp16:
|
||||
model.half()
|
||||
model.prepare_for_inference_(cfg)
|
||||
self.models = models
|
||||
@@ -68,6 +71,9 @@ class OfaForImageCaptioning(Model):
|
||||
self.task = task
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
import fairseq.utils
|
||||
if torch.cuda.is_available():
|
||||
input = fairseq.utils.move_to_cuda(input, device=self._device)
|
||||
results, _ = self.eval_caption(self.task, self.generator, self.models,
|
||||
input)
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user