Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-24 03:59:23 +01:00)
[to #42322933] update ofa caption model
Move the caption pipeline implementation down from the pipeline layer into the model, and split the preprocessor out into its own component.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9081211
* [to #41669377] docs and tools refinement and release
1. add build_doc linter script
2. add sphinx-docs support
3. add development doc and api doc
4. change version to 0.1.0 for the first internal release version
Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307
* [to #41669377] add pipeline tutorial and fix bugs
1. add pipeline tutorial
2. fix bugs when using the pipeline with certain models and preprocessors
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8814301
* refine doc
* refine doc
* upload ofa for caption (with source code but not whl)
* remove data in gitignore
* append uncommitted data dir in ofa
* remove ofa_dir, use ofa.whl instead.
* update BPE
* roll back changes used in debugging.
* Merge branch 'master' into ofa/image_caption
# Conflicts:
# docs/README.md
# docs/source/conf.py
# docs/source/index.rst
# docs/source/tutorials/pipeline.md
# maas_lib/models/nlp/sequence_classification_model.py
# maas_lib/pipelines/builder.py
# maas_lib/version.py
# setup.py
# tests/pipelines/test_text_classification.py
* 1. fix a bug in pipelines/builder.py.
2. rename model_path to model in image_captioning.py.
* 1. rename test_image_captioning.py.
* format all files using pre-commit.
* add fairseq in requirements.txt
* add fairseq in requirements.txt
* change fairseq path from a git repo to a whl on OSS in ofa.txt.
* change module_name to 'ofa'
* Merge remote-tracking branch 'origin/master' into ofa/image_caption
# Conflicts:
# maas_lib/pipelines/builder.py
* optimize requirements for ofa / refine image_captioning.py
* uncommitted change.
* feat: fix conflict, auto-commit by WebIDE
* Merge remote-tracking branch 'origin/master' into ofa/image_caption
# Conflicts:
# maas_lib/pipelines/multi_modal/__init__.py
# modelscope/pipelines/multi_modal/image_captioning.py
# tests/pipelines/test_image_captioning.py
* merge master
* merge master
* merge master
* rename
* Merge remote-tracking branch 'origin/master' into ofa/nlu
* add caption model
* Merge remote-tracking branch 'origin/master' into ofa/nlu
* update ofa caption model
* fix some typos, update unit test
* use local test image
* use local test image
* refactor, ofa -> multi_model
* merge master
* Delete image_caption_pipeline.py
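The net effect of the refactor, as exercised by the updated unit test at the bottom of this diff (a minimal sketch; it assumes the damo/ofa_image-caption_coco_large_en weights are reachable on the model hub):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

img_captioning = pipeline(
    Tasks.image_captioning, model='damo/ofa_image-caption_coco_large_en')
result = img_captioning('data/test/images/image_captioning.png')
print(result['caption'])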
data/test/images/image_captioning.png (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af83a94899a6d23339c3ecc5c4c58c57c835af57b531a2f4c50461184f820141
+size 603621
modelscope/models/__init__.py
@@ -4,4 +4,5 @@ from .audio.tts.am import SambertNetHifi16k
 from .audio.tts.vocoder import Hifigan16k
 from .base import Model
 from .builder import MODELS, build_model
+from .multi_model import OfaForImageCaptioning
 from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
modelscope/models/multi_model/__init__.py (new file, +1)
@@ -0,0 +1 @@
+from .image_captioning_model import OfaForImageCaptioning
modelscope/models/multi_model/image_captioning_model.py (new file, +80)
@@ -0,0 +1,80 @@
+import os.path as osp
+from typing import Any, Dict
+
+from PIL import Image
+
+from modelscope.utils.constant import ModelFile, Tasks
+from ..base import Model
+from ..builder import MODELS
+
+__all__ = ['OfaForImageCaptioning']
+
+
+@MODELS.register_module(
+    Tasks.image_captioning, module_name=r'ofa-image-captioning')
+class OfaForImageCaptioning(Model):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        super().__init__(model_dir=model_dir, *args, **kwargs)
+        ckpt_name = ModelFile.TORCH_MODEL_FILE
+        local_model = osp.join(model_dir, ckpt_name)
+        bpe_dir = model_dir
+        # turn on cuda if GPU is available
+        from fairseq import checkpoint_utils, tasks, utils
+        from ofa.tasks.mm_tasks import CaptionTask
+        from ofa.utils.eval_utils import eval_caption
+        self.eval_caption = eval_caption
+
+        tasks.register_task('caption', CaptionTask)
+        use_cuda = kwargs['use_cuda'] if 'use_cuda' in kwargs else False
+        use_fp16 = kwargs[
+            'use_fp16'] if 'use_fp16' in kwargs and use_cuda else False
+        overrides = {
+            'bpe_dir': bpe_dir,
+            'eval_cider': False,
+            'beam': 5,
+            'max_len_b': 16,
+            'no_repeat_ngram_size': 3,
+            'seed': 7
+        }
+        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+            utils.split_paths(local_model), arg_overrides=overrides)
+
+        # Move models to GPU
+        for model in models:
+            model.eval()
+            if use_cuda:
+                model.cuda()
+            if use_fp16:
+                model.half()
+            model.prepare_for_inference_(cfg)
+        self.models = models
+        # Initialize generator
+        self.generator = task.build_generator(models, cfg.generation)
+
+        # Initialize transform
+        from torchvision import transforms
+        mean = [0.5, 0.5, 0.5]
+        std = [0.5, 0.5, 0.5]
+
+        self.patch_resize_transform = transforms.Compose([
+            lambda image: image.convert('RGB'),
+            transforms.Resize(
+                (cfg.task.patch_image_size, cfg.task.patch_image_size),
+                interpolation=Image.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+        self.task = task
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        results, _ = self.eval_caption(self.task, self.generator, self.models,
+                                       input)
+        return {
+            'image_id': results[0]['image_id'],
+            'caption': results[0]['caption']
+        }
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        # What should we do here ?
+        return inputs
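Taken together with the preprocessor introduced further down, the model can also be driven without the pipeline wrapper; a minimal sketch (whether Model.from_pretrained resolves hub ids at this commit is an assumption):

from modelscope.models import Model
from modelscope.preprocessors import OfaImageCaptionPreprocessor

# Assumption: the hub id resolves to a model_dir that contains
# ModelFile.TORCH_MODEL_FILE plus the BPE files.
model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en')
preprocessor = OfaImageCaptionPreprocessor(
    model_dir='damo/ofa_image-caption_coco_large_en')

sample = preprocessor('data/test/images/image_captioning.png')
output = model.forward(sample)  # {'image_id': ..., 'caption': ...}
print(output['caption'])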
modelscope/pipelines/builder.py
@@ -24,7 +24,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
     Tasks.text_generation: ('palm2.0',
                             'damo/nlp_palm2.0_text-generation_chinese-base'),
-    Tasks.image_captioning: ('ofa', None),
+    Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'),
     Tasks.image_generation:
     ('person-image-cartoon',
      'damo/cv_unet_person-image-cartoon_compound-models'),
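With a real default model registered for the task, callers should be able to omit the model argument entirely; a sketch (this assumes the pipeline factory falls back to DEFAULT_MODEL_FOR_PIPELINE when no model is given, which this diff does not show):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Expected to resolve to ('ofa', 'damo/ofa_image-caption_coco_large_en')
# via the table above.
img_captioning = pipeline(Tasks.image_captioning)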
modelscope/pipelines/multi_modal/__init__.py
@@ -1 +1 @@
-from .image_caption_pipeline import ImageCaptionPipeline
+from .image_captioning_pipeline import ImageCaptionPipeline
modelscope/pipelines/multi_modal/image_captioning_pipeline.py (new file, +33)
@@ -0,0 +1,33 @@
+from typing import Any, Dict, Union
+
+from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from ..base import Model, Pipeline
+from ..builder import PIPELINES
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa')
+class ImageCaptionPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: [Preprocessor] = None,
+                 **kwargs):
+        super().__init__()
+        assert isinstance(model, str) or isinstance(model, Model), \
+            'model must be a single str or OfaForImageCaptioning'
+        if isinstance(model, str):
+            pipe_model = Model.from_pretrained(model)
+        elif isinstance(model, Model):
+            pipe_model = model
+        else:
+            raise NotImplementedError
+        if preprocessor is None and pipe_model:
+            preprocessor = OfaImageCaptionPreprocessor(model_dir=model)
+        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
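Because the constructor accepts Union[Model, str], a preloaded model can be reused across pipelines; a sketch (assuming the pipeline factory forwards model and preprocessor through to this constructor). Note that the fallback branch above passes model, not the model's directory, to OfaImageCaptionPreprocessor, so with a Model instance the preprocessor should be given explicitly:

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import OfaImageCaptionPreprocessor
from modelscope.utils.constant import Tasks

model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en')
# Explicit preprocessor: the fallback above would hand the Model instance
# itself to the model_dir argument.
preprocessor = OfaImageCaptionPreprocessor(
    model_dir='damo/ofa_image-caption_coco_large_en')
img_captioning = pipeline(
    Tasks.image_captioning, model=model, preprocessor=preprocessor)
result = img_captioning('data/test/images/image_captioning.png')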
modelscope/preprocessors/__init__.py
@@ -5,5 +5,6 @@ from .base import Preprocessor
 from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
+from .multi_model import OfaImageCaptionPreprocessor
 from .nlp import * # noqa F403
 from .text_to_speech import * # noqa F403
modelscope/preprocessors/multi_model.py (moved from modelscope/pipelines/multi_modal/image_caption_pipeline.py)
@@ -1,32 +1,50 @@
from typing import Any, Dict
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict, Union

import numpy as np
import torch
from maas_hub.snapshot_download import snapshot_download
from PIL import Image

from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES
from modelscope.utils.constant import Fields, ModelFile
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS
from .image import load_image

logger = get_logger()
__all__ = [
    'OfaImageCaptionPreprocessor',
]


@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa')
class ImageCaptionPipeline(Pipeline):
    # TODO: refine using modelhub
    def __init__(self, model: str, bpe_dir: str):
        super().__init__()
        # turn on cuda if GPU is available
@PREPROCESSORS.register_module(
    Fields.multi_modal, module_name=r'ofa-image-caption')
class OfaImageCaptionPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """preprocess the data via the vocab.txt from the `model_dir` path

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)

        if osp.exists(model_dir):
            local_model_dir = model_dir
        else:
            cache_path = get_model_cache_dir(model_dir)
            local_model_dir = cache_path if osp.exists(
                cache_path) else snapshot_download(model_dir)
        local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE)
        bpe_dir = local_model_dir

        from fairseq import checkpoint_utils, tasks, utils
        from ofa.tasks.mm_tasks import CaptionTask

        tasks.register_task('caption', CaptionTask)
        use_cuda = False
        # use fp16 only when GPU is available
        use_fp16 = False

        overrides = {
            'bpe_dir': bpe_dir,
            'eval_cider': False,
@@ -35,21 +53,9 @@ class ImageCaptionPipeline(Pipeline):
            'no_repeat_ngram_size': 3,
            'seed': 7
        }
        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(model), arg_overrides=overrides)

        # Move models to GPU
        for model in models:
            model.eval()
            if use_cuda:
                model.cuda()
            if use_fp16:
                model.half()
            model.prepare_for_inference_(cfg)
        self.models = models
        # Initialize generator
        self.generator = task.build_generator(models, cfg.generation)

        model, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(local_model), arg_overrides=overrides)
        del model
        # Initialize transform
        from torchvision import transforms
        mean = [0.5, 0.5, 0.5]
@@ -69,7 +75,8 @@ class ImageCaptionPipeline(Pipeline):
        self.eos_item = torch.LongTensor([task.src_dict.eos()])
        self.pad_idx = task.src_dict.pad()

    def preprocess(self, input: Input) -> Dict[str, Any]:
    @type_assert(object, (str, tuple))
    def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:

        def encode_text(text, length=None, append_bos=False, append_eos=False):
            s = self.task.tgt_dict.encode_line(
@@ -88,7 +95,7 @@ class ImageCaptionPipeline(Pipeline):
            patch_image = self.patch_resize_transform(input).unsqueeze(0)
        else:
            patch_image = self.patch_resize_transform(
                load_image(input)).unsqueeze(0)
                load_image(data)).unsqueeze(0)
        patch_mask = torch.tensor([True])
        text = 'what does the image describe?'
        src_text = encode_text(
@@ -105,17 +112,3 @@ class ImageCaptionPipeline(Pipeline):
            }
        }
        return sample

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        from ofa.utils.eval_utils import eval_caption

        results, _ = eval_caption(self.task, self.generator, self.models,
                                  input)
        return {
            'image_id': results[0]['image_id'],
            'caption': results[0]['caption']
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # What should we do here ?
        return inputs
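The preprocessor is registered under Fields.multi_modal, so it should also be buildable from config through the registry; a sketch (the exact build_preprocessor signature is an assumption, inferred from the import in modelscope/preprocessors/__init__.py):

from modelscope.preprocessors import build_preprocessor
from modelscope.utils.constant import Fields

# 'type' matches the module_name used in the register_module call above.
cfg = dict(
    type='ofa-image-caption',
    model_dir='damo/ofa_image-caption_coco_large_en')
preprocessor = build_preprocessor(cfg, Fields.multi_modal)
sample = preprocessor('data/test/images/image_captioning.png')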
tests/pipelines/test_image_captioning.py
@@ -1,10 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

-import os
-import tempfile
 import unittest

-from modelscope.fileio import File
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level
@@ -12,23 +9,13 @@ from modelscope.utils.test_utils import test_level

 class ImageCaptionTest(unittest.TestCase):

-    @unittest.skip('skip before model is restored in model hub')
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run(self):
-        model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt'
-
-        os.system(
-            'wget https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/BPE.zip'
-        )
-        os.system('unzip BPE.zip')
-        bpe_dir = './BPE'
-
-        with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile:
-            ofile.write(File.read(model))
-            img_captioning = pipeline(
-                Tasks.image_captioning, model=ofile.name, bpe_dir=bpe_dir)
-
-            result = img_captioning('data/test/images/image_matting.png')
-            print(result['caption'])
+        img_captioning = pipeline(
+            Tasks.image_captioning,
+            model='damo/ofa_image-caption_coco_large_en')
+        result = img_captioning('data/test/images/image_captioning.png')
+        print(result['caption'])


 if __name__ == '__main__':