From a2cf3d619e66cd68100240e42254f39acf4c6992 Mon Sep 17 00:00:00 2001
From: "yichang.zyc"
Date: Tue, 21 Jun 2022 11:48:49 +0800
Subject: [PATCH] [to #42322933] update ofa caption model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Move the caption pipeline implementation from the pipeline layer down into the model, and split out the preprocessor.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9081211

* [to #41669377] docs and tools refinement and release
  1. add build_doc linter script
  2. add sphinx-docs support
  3. add development doc and api doc
  4. change version to 0.1.0 for the first internal release version
  Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307
* [to #41669377] add pipeline tutorial and fix bugs
  1. add pipeline tutorial
  2. fix bugs when using pipeline with certain model and preprocessor
  Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8814301
* refine doc
* refine doc
* upload ofa for caption (with source code but not whl)
* remove data in gitignore
* append uncommitted data dir in ofa
* remove ofa_dir, use ofa.whl instead
* update BPE
* roll back changes used in debugging
* Merge branch 'master' into ofa/image_caption
  # Conflicts:
  #   docs/README.md
  #   docs/source/conf.py
  #   docs/source/index.rst
  #   docs/source/tutorials/pipeline.md
  #   maas_lib/models/nlp/sequence_classification_model.py
  #   maas_lib/pipelines/builder.py
  #   maas_lib/version.py
  #   setup.py
  #   tests/pipelines/test_text_classification.py
* 1. fix a bug in pipelines/builder.py
  2. modify model_path to model in image_captioning.py
* rename test_image_captioning.py
* format all files using pre-commit
* add fairseq in requirements.txt
* add fairseq in requirements.txt
* change fairseq path from a git repo to a whl on OSS in ofa.txt
* change module_name to 'ofa'
* Merge remote-tracking branch 'origin/master' into ofa/image_caption
  # Conflicts:
  #   maas_lib/pipelines/builder.py
* optimize requirements for ofa / refine image_captioning.py
* uncommitted change
* feat: Fix conflict, auto commit by WebIDE
* Merge remote-tracking branch 'origin/master' into ofa/image_caption
  # Conflicts:
  #   maas_lib/pipelines/multi_modal/__init__.py
  #   modelscope/pipelines/multi_modal/image_captioning.py
  #   tests/pipelines/test_image_captioning.py
* merge master
* merge master
* merge master
* rename
* Merge remote-tracking branch 'origin/master' into ofa/nlu
* add caption model
* Merge remote-tracking branch 'origin/master' into ofa/nlu
* update ofa caption model
* fix some typos, update unittest
* use local test image
* use local test image
* refactor, ofa -> multi_model
* merge master
* Delete image_caption_pipeline.py
---
 data/test/images/image_captioning.png        |  3 +
 modelscope/models/__init__.py                |  1 +
 modelscope/models/multi_model/__init__.py    |  1 +
 .../multi_model/image_captioning_model.py    | 80 +++++++++++++++++
 modelscope/pipelines/builder.py              |  2 +-
 modelscope/pipelines/multi_modal/__init__.py |  2 +-
 .../multi_modal/image_captioning_pipeline.py | 33 +++++++
 modelscope/preprocessors/__init__.py         |  1 +
 .../multi_model.py}                          | 89 +++++++++----------
 tests/pipelines/test_image_captioning.py     | 25 ++----
 10 files changed, 168 insertions(+), 69 deletions(-)
 create mode 100644 data/test/images/image_captioning.png
 create mode 100644 modelscope/models/multi_model/__init__.py
 create mode 100644 modelscope/models/multi_model/image_captioning_model.py
 create mode 100644 modelscope/pipelines/multi_modal/image_captioning_pipeline.py
 rename modelscope/{pipelines/multi_modal/image_caption_pipeline.py => preprocessors/multi_model.py} (57%)

diff --git a/data/test/images/image_captioning.png b/data/test/images/image_captioning.png
new file mode 100644
index 00000000..de3f1918
--- /dev/null
+++ b/data/test/images/image_captioning.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af83a94899a6d23339c3ecc5c4c58c57c835af57b531a2f4c50461184f820141
+size 603621
diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py
index 7d70e6ca..f873dcca 100644
--- a/modelscope/models/__init__.py
+++ b/modelscope/models/__init__.py
@@ -4,4 +4,5 @@ from .audio.tts.am import SambertNetHifi16k
 from .audio.tts.vocoder import Hifigan16k
 from .base import Model
 from .builder import MODELS, build_model
+from .multi_model import OfaForImageCaptioning
 from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
diff --git a/modelscope/models/multi_model/__init__.py b/modelscope/models/multi_model/__init__.py
new file mode 100644
index 00000000..02e8d6ab
--- /dev/null
+++ b/modelscope/models/multi_model/__init__.py
@@ -0,0 +1 @@
+from .image_captioning_model import OfaForImageCaptioning
diff --git a/modelscope/models/multi_model/image_captioning_model.py b/modelscope/models/multi_model/image_captioning_model.py
new file mode 100644
index 00000000..fad0663e
--- /dev/null
+++ b/modelscope/models/multi_model/image_captioning_model.py
@@ -0,0 +1,80 @@
+import os.path as osp
+from typing import Any, Dict
+
+from PIL import Image
+
+from modelscope.utils.constant import ModelFile, Tasks
+from ..base import Model
+from ..builder import MODELS
+
+__all__ = ['OfaForImageCaptioning']
+
+
+@MODELS.register_module(
+    Tasks.image_captioning, module_name=r'ofa-image-captioning')
+class OfaForImageCaptioning(Model):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        super().__init__(model_dir=model_dir, *args, **kwargs)
+        ckpt_name = ModelFile.TORCH_MODEL_FILE
+        local_model = osp.join(model_dir, ckpt_name)
+        bpe_dir = model_dir
+        # turn on cuda if GPU is available
+        from fairseq import checkpoint_utils, tasks, utils
+        from ofa.tasks.mm_tasks import CaptionTask
+        from ofa.utils.eval_utils import eval_caption
+        self.eval_caption = eval_caption
+
+        tasks.register_task('caption', CaptionTask)
+        use_cuda = kwargs['use_cuda'] if 'use_cuda' in kwargs else False
+        use_fp16 = kwargs[
+            'use_fp16'] if 'use_fp16' in kwargs and use_cuda else False
+        overrides = {
+            'bpe_dir': bpe_dir,
+            'eval_cider': False,
+            'beam': 5,
+            'max_len_b': 16,
+            'no_repeat_ngram_size': 3,
+            'seed': 7
+        }
+        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+            utils.split_paths(local_model), arg_overrides=overrides)
+
+        # Move models to GPU
+        for model in models:
+            model.eval()
+            if use_cuda:
+                model.cuda()
+            if use_fp16:
+                model.half()
+            model.prepare_for_inference_(cfg)
+        self.models = models
+        # Initialize generator
+        self.generator = task.build_generator(models, cfg.generation)
+
+        # Initialize transform
+        from torchvision import transforms
+        mean = [0.5, 0.5, 0.5]
+        std = [0.5, 0.5, 0.5]
+
+        self.patch_resize_transform = transforms.Compose([
+            lambda image: image.convert('RGB'),
+            transforms.Resize(
+                (cfg.task.patch_image_size, cfg.task.patch_image_size),
+                interpolation=Image.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+        self.task = task
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        results, _ = self.eval_caption(self.task, self.generator, self.models,
+                                       input)
+        return {
+            'image_id': results[0]['image_id'],
+            'caption': results[0]['caption']
+        }
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        # What should we do here ?
+        return inputs
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 6e2c791d..8897cf31 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -24,7 +24,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
     Tasks.text_generation:
     ('palm2.0', 'damo/nlp_palm2.0_text-generation_chinese-base'),
-    Tasks.image_captioning: ('ofa', None),
+    Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'),
     Tasks.image_generation:
     ('person-image-cartoon',
      'damo/cv_unet_person-image-cartoon_compound-models'),
diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py
index b1ee121c..b7402b93 100644
--- a/modelscope/pipelines/multi_modal/__init__.py
+++ b/modelscope/pipelines/multi_modal/__init__.py
@@ -1 +1 @@
-from .image_caption_pipeline import ImageCaptionPipeline
+from .image_captioning_pipeline import ImageCaptionPipeline
diff --git a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
new file mode 100644
index 00000000..f0b1f53c
--- /dev/null
+++ b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
@@ -0,0 +1,33 @@
+from typing import Any, Dict, Union
+
+from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from ..base import Model, Pipeline
+from ..builder import PIPELINES
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa')
+class ImageCaptionPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: [Preprocessor] = None,
+                 **kwargs):
+        super().__init__()
+        assert isinstance(model, str) or isinstance(model, Model), \
+            'model must be a single str or OfaForImageCaptioning'
+        if isinstance(model, str):
+            pipe_model = Model.from_pretrained(model)
+        elif isinstance(model, Model):
+            pipe_model = model
+        else:
+            raise NotImplementedError
+        if preprocessor is None and pipe_model:
+            preprocessor = OfaImageCaptionPreprocessor(model_dir=model)
+        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index e3ae4c40..50860514 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -5,5 +5,6 @@ from .base import Preprocessor
 from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
+from .multi_model import OfaImageCaptionPreprocessor
 from .nlp import *  # noqa F403
 from .text_to_speech import *  # noqa F403
diff --git a/modelscope/pipelines/multi_modal/image_caption_pipeline.py b/modelscope/preprocessors/multi_model.py
similarity index 57%
rename from modelscope/pipelines/multi_modal/image_caption_pipeline.py
rename to modelscope/preprocessors/multi_model.py
index 3e5f49d0..de211611 100644
--- a/modelscope/pipelines/multi_modal/image_caption_pipeline.py
+++ b/modelscope/preprocessors/multi_model.py
@@ -1,32 +1,50 @@
-from typing import Any, Dict
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+from typing import Any, Dict, Union
 
 import numpy as np
 import torch
+from maas_hub.snapshot_download import snapshot_download
 from PIL import Image
 
-from modelscope.pipelines.base import Input
-from modelscope.preprocessors import load_image
-from modelscope.utils.constant import Tasks
-from modelscope.utils.logger import get_logger
-from ..base import Pipeline
-from ..builder import PIPELINES
+from modelscope.utils.constant import Fields, ModelFile
+from modelscope.utils.hub import get_model_cache_dir
+from modelscope.utils.type_assert import type_assert
+from .base import Preprocessor
+from .builder import PREPROCESSORS
+from .image import load_image
 
-logger = get_logger()
+__all__ = [
+    'OfaImageCaptionPreprocessor',
+]
 
 
-@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa')
-class ImageCaptionPipeline(Pipeline):
-    # TODO: refine using modelhub
-    def __init__(self, model: str, bpe_dir: str):
-        super().__init__()
-        # turn on cuda if GPU is available
+@PREPROCESSORS.register_module(
+    Fields.multi_modal, module_name=r'ofa-image-caption')
+class OfaImageCaptionPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+
+        if osp.exists(model_dir):
+            local_model_dir = model_dir
+        else:
+            cache_path = get_model_cache_dir(model_dir)
+            local_model_dir = cache_path if osp.exists(
+                cache_path) else snapshot_download(model_dir)
+        local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE)
+        bpe_dir = local_model_dir
+
         from fairseq import checkpoint_utils, tasks, utils
         from ofa.tasks.mm_tasks import CaptionTask
 
         tasks.register_task('caption', CaptionTask)
-        use_cuda = False
-        # use fp16 only when GPU is available
-        use_fp16 = False
+
         overrides = {
             'bpe_dir': bpe_dir,
             'eval_cider': False,
@@ -35,21 +53,9 @@ class ImageCaptionPipeline(Pipeline):
             'no_repeat_ngram_size': 3,
             'seed': 7
         }
-        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
-            utils.split_paths(model), arg_overrides=overrides)
-
-        # Move models to GPU
-        for model in models:
-            model.eval()
-            if use_cuda:
-                model.cuda()
-            if use_fp16:
-                model.half()
-            model.prepare_for_inference_(cfg)
-        self.models = models
-        # Initialize generator
-        self.generator = task.build_generator(models, cfg.generation)
-
+        model, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+            utils.split_paths(local_model), arg_overrides=overrides)
+        del model
         # Initialize transform
         from torchvision import transforms
         mean = [0.5, 0.5, 0.5]
@@ -69,7 +75,8 @@ class ImageCaptionPipeline(Pipeline):
         self.eos_item = torch.LongTensor([task.src_dict.eos()])
         self.pad_idx = task.src_dict.pad()
 
-    def preprocess(self, input: Input) -> Dict[str, Any]:
+    @type_assert(object, (str, tuple))
+    def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:
 
         def encode_text(text, length=None, append_bos=False, append_eos=False):
             s = self.task.tgt_dict.encode_line(
@@ -88,7 +95,7 @@ class ImageCaptionPipeline(Pipeline):
             patch_image = self.patch_resize_transform(input).unsqueeze(0)
         else:
             patch_image = self.patch_resize_transform(
-                load_image(input)).unsqueeze(0)
+                load_image(data)).unsqueeze(0)
         patch_mask = torch.tensor([True])
         text = 'what does the image describe?'
         src_text = encode_text(
@@ -105,17 +112,3 @@ class ImageCaptionPipeline(Pipeline):
             }
         }
         return sample
-
-    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
-        from ofa.utils.eval_utils import eval_caption
-
-        results, _ = eval_caption(self.task, self.generator, self.models,
-                                  input)
-        return {
-            'image_id': results[0]['image_id'],
-            'caption': results[0]['caption']
-        }
-
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        # What should we do here ?
-        return inputs
diff --git a/tests/pipelines/test_image_captioning.py b/tests/pipelines/test_image_captioning.py
index 74a65806..5fa6ff49 100644
--- a/tests/pipelines/test_image_captioning.py
+++ b/tests/pipelines/test_image_captioning.py
@@ -1,10 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
-import os
-import tempfile
 import unittest
 
-from modelscope.fileio import File
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level
@@ -12,23 +9,13 @@ from modelscope.utils.test_utils import test_level
 
 class ImageCaptionTest(unittest.TestCase):
 
-    @unittest.skip('skip before model is restored in model hub')
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run(self):
-        model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt'
-
-        os.system(
-            'wget https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/BPE.zip'
-        )
-        os.system('unzip BPE.zip')
-        bpe_dir = './BPE'
-
-        with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile:
-            ofile.write(File.read(model))
-            img_captioning = pipeline(
-                Tasks.image_captioning, model=ofile.name, bpe_dir=bpe_dir)
-
-            result = img_captioning('data/test/images/image_matting.png')
-            print(result['caption'])
+        img_captioning = pipeline(
+            Tasks.image_captioning,
+            model='damo/ofa_image-caption_coco_large_en')
+        result = img_captioning('data/test/images/image_captioning.png')
+        print(result['caption'])
 
 
 if __name__ == '__main__':
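
For reference, a minimal sketch of how the model/preprocessor split introduced by this patch fits together when used directly, rather than through the registered pipeline. The `local_dir` path is a placeholder for a locally downloaded copy of 'damo/ofa_image-caption_coco_large_en'; in normal use the ImageCaptionPipeline above does this wiring itself via Model.from_pretrained.

# Sketch only: direct use of the pieces added in this patch.
from modelscope.models.multi_model import OfaForImageCaptioning
from modelscope.preprocessors import OfaImageCaptionPreprocessor

# Placeholder path: a local snapshot of 'damo/ofa_image-caption_coco_large_en'
# containing the torch checkpoint and BPE files.
local_dir = '/path/to/ofa_image-caption_coco_large_en'

preprocessor = OfaImageCaptionPreprocessor(model_dir=local_dir)
model = OfaForImageCaptioning(model_dir=local_dir)

# The preprocessor turns an image path into the sample dict the model expects;
# forward() returns {'image_id': ..., 'caption': ...}.
sample = preprocessor('data/test/images/image_captioning.png')
output = model.forward(sample)
print(output['caption'])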