From d1d2c96dd92aefc8df4bcde0c33e0796b2b07e2b Mon Sep 17 00:00:00 2001 From: "yichang.zyc" Date: Thu, 28 Jul 2022 23:25:17 +0800 Subject: [PATCH] add 6 ofa tasks --- data/test/images/image_classification.png | 3 + data/test/images/visual_grounding.png | 3 + .../test/images/visual_question_answering.png | 3 + modelscope/metainfo.py | 4 + modelscope/models/multi_modal/__init__.py | 3 +- .../multi_modal/image_captioning_model.py | 86 --- .../ofa/generate/sequence_generator.py | 47 +- .../models/multi_modal/ofa/utils/constant.py | 13 + .../models/multi_modal/ofa/utils/utils.py | 19 + .../models/multi_modal/ofa_for_all_tasks.py | 259 ++++++++ .../ofa_for_image_captioning_model.py | 53 -- modelscope/pipelines/cv/__init__.py | 3 +- .../cv/image_classification_pipeline.py | 38 +- modelscope/pipelines/multi_modal/__init__.py | 8 +- .../multi_modal/image_captioning_pipeline.py | 12 +- .../multi_modal/visual_entailment_pipeline.py | 42 ++ .../multi_modal/visual_grounding_pipeline.py | 42 ++ .../visual_question_answering_pipeline.py | 16 +- modelscope/pipelines/nlp/__init__.py | 4 + .../pipelines/nlp/summarization_pipeline.py | 42 ++ .../nlp/text_classification_pipeline.py | 42 ++ modelscope/preprocessors/__init__.py | 8 +- modelscope/preprocessors/multi_modal.py | 64 +- modelscope/preprocessors/ofa/__init__.py | 8 + modelscope/preprocessors/ofa/base.py | 117 ++++ .../preprocessors/ofa/image_captioning.py | 42 ++ .../preprocessors/ofa/image_classification.py | 43 ++ modelscope/preprocessors/ofa/summarization.py | 37 ++ .../preprocessors/ofa/text_classification.py | 38 ++ .../preprocessors/ofa/utils/__init__.py | 0 modelscope/preprocessors/ofa/utils/collate.py | 109 ++++ .../preprocessors/ofa/utils/random_help.py | 42 ++ .../preprocessors/ofa/utils/transforms.py | 557 ++++++++++++++++++ .../preprocessors/ofa/utils/vision_helper.py | 357 +++++++++++ .../preprocessors/ofa/visual_entailment.py | 62 ++ .../preprocessors/ofa/visual_grounding.py | 50 ++ .../ofa/visual_question_answering.py | 52 ++ modelscope/utils/constant.py | 1 + modelscope/utils/trie.py | 29 + tests/pipelines/test_image_captioning.py | 23 - tests/pipelines/test_ofa_tasks.py | 179 ++++++ 41 files changed, 2290 insertions(+), 270 deletions(-) create mode 100644 data/test/images/image_classification.png create mode 100644 data/test/images/visual_grounding.png create mode 100644 data/test/images/visual_question_answering.png delete mode 100644 modelscope/models/multi_modal/image_captioning_model.py create mode 100644 modelscope/models/multi_modal/ofa/utils/constant.py create mode 100644 modelscope/models/multi_modal/ofa/utils/utils.py create mode 100644 modelscope/models/multi_modal/ofa_for_all_tasks.py delete mode 100644 modelscope/models/multi_modal/ofa_for_image_captioning_model.py create mode 100644 modelscope/pipelines/multi_modal/visual_entailment_pipeline.py create mode 100644 modelscope/pipelines/multi_modal/visual_grounding_pipeline.py create mode 100644 modelscope/pipelines/nlp/summarization_pipeline.py create mode 100644 modelscope/pipelines/nlp/text_classification_pipeline.py create mode 100644 modelscope/preprocessors/ofa/__init__.py create mode 100644 modelscope/preprocessors/ofa/base.py create mode 100644 modelscope/preprocessors/ofa/image_captioning.py create mode 100644 modelscope/preprocessors/ofa/image_classification.py create mode 100644 modelscope/preprocessors/ofa/summarization.py create mode 100644 modelscope/preprocessors/ofa/text_classification.py create mode 100644 modelscope/preprocessors/ofa/utils/__init__.py 
create mode 100644 modelscope/preprocessors/ofa/utils/collate.py create mode 100644 modelscope/preprocessors/ofa/utils/random_help.py create mode 100644 modelscope/preprocessors/ofa/utils/transforms.py create mode 100644 modelscope/preprocessors/ofa/utils/vision_helper.py create mode 100644 modelscope/preprocessors/ofa/visual_entailment.py create mode 100644 modelscope/preprocessors/ofa/visual_grounding.py create mode 100644 modelscope/preprocessors/ofa/visual_question_answering.py create mode 100644 modelscope/utils/trie.py delete mode 100644 tests/pipelines/test_image_captioning.py create mode 100644 tests/pipelines/test_ofa_tasks.py diff --git a/data/test/images/image_classification.png b/data/test/images/image_classification.png new file mode 100644 index 00000000..3d1a2f8c --- /dev/null +++ b/data/test/images/image_classification.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8bdb9627c3a40897e84ee186b2a959f272790571644224e1d2efca443f867e12 +size 202823 diff --git a/data/test/images/visual_grounding.png b/data/test/images/visual_grounding.png new file mode 100644 index 00000000..a37791ec --- /dev/null +++ b/data/test/images/visual_grounding.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b89734b9c9d89342e58fbe406d3b9bdc8e07447cb170a4ae2743000471fc969 +size 23069 diff --git a/data/test/images/visual_question_answering.png b/data/test/images/visual_question_answering.png new file mode 100644 index 00000000..e39d34a0 --- /dev/null +++ b/data/test/images/visual_question_answering.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d53e9fbdd129b234dcbec9b9fe6a15a0e05820e802a873f95955574267bbd2ff +size 121141 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 8ea9e7ed..20aa3586 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -69,6 +69,7 @@ class Pipelines(object): action_recognition = 'TAdaConv_action-recognition' animal_recognation = 'resnet101-animal_recog' cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' + image_classification = 'image-classification' face_detection = 'resnet-face-detection-scrfd10gkps' live_category = 'live-category' general_image_classification = 'vit-base_image-classification_ImageNet-labels' @@ -92,6 +93,7 @@ class Pipelines(object): text_generation = 'text-generation' sentiment_analysis = 'sentiment-analysis' sentiment_classification = 'sentiment-classification' + text_classification = 'text-classification' fill_mask = 'fill-mask' csanmt_translation = 'csanmt-translation' nli = 'nli' @@ -113,6 +115,8 @@ class Pipelines(object): multi_modal_embedding = 'multi-modal-embedding' generative_multi_modal_embedding = 'generative-multi-modal-embedding' visual_question_answering = 'visual-question-answering' + visual_grounding = 'visual-grounding' + visual_entailment = 'visual-entailment' text_to_image_synthesis = 'text-to-image-synthesis' video_multi_modal_embedding = 'video-multi-modal-embedding' diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index 6e2864d1..8e6e2a39 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: from .mmr import VideoCLIPForMultiModalEmbedding from .mplug_for_visual_question_answering import \ MPlugForVisualQuestionAnswering - from .ofa_for_image_captioning_model import OfaForImageCaptioning else: _import_structure = { @@ -21,7 +20,7 @@ else: 'mmr': ['VideoCLIPForMultiModalEmbedding'], 
'mplug_for_visual_question_answering': ['MPlugForVisualQuestionAnswering'], - 'ofa_for_image_captioning_model': ['OfaForImageCaptioning'] + 'ofa_for_all_tasks': ['OfaForAllTasks'] } import sys diff --git a/modelscope/models/multi_modal/image_captioning_model.py b/modelscope/models/multi_modal/image_captioning_model.py deleted file mode 100644 index 3e638a05..00000000 --- a/modelscope/models/multi_modal/image_captioning_model.py +++ /dev/null @@ -1,86 +0,0 @@ -import os.path as osp -from typing import Any, Dict - -import torch.cuda -from PIL import Image - -from modelscope.metainfo import Models -from modelscope.models.base import Model -from modelscope.models.builder import MODELS -from modelscope.utils.constant import ModelFile, Tasks - -__all__ = ['OfaForImageCaptioning'] - - -@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) -class OfaForImageCaptioning(Model): - - def __init__(self, model_dir, *args, **kwargs): - super().__init__(model_dir=model_dir, *args, **kwargs) - ckpt_name = ModelFile.TORCH_MODEL_FILE - local_model = osp.join(model_dir, ckpt_name) - bpe_dir = model_dir - # turn on cuda if GPU is available - from fairseq import checkpoint_utils, tasks, utils - from ofa.tasks.mm_tasks import CaptionTask - from ofa.utils.eval_utils import eval_caption - self.eval_caption = eval_caption - tasks.register_task('caption', CaptionTask) - if torch.cuda.is_available(): - self._device = torch.device('cuda') - else: - self._device = torch.device('cpu') - self.use_fp16 = kwargs[ - 'use_fp16'] if 'use_fp16' in kwargs and torch.cuda.is_available()\ - else False - overrides = { - 'bpe_dir': bpe_dir, - 'eval_cider': False, - 'beam': 5, - 'max_len_b': 16, - 'no_repeat_ngram_size': 3, - 'seed': 7 - } - models, cfg, task = checkpoint_utils.load_model_ensemble_and_task( - utils.split_paths(local_model), arg_overrides=overrides) - # Move models to GPU - for model in models: - model.eval() - model.to(self._device) - if self.use_fp16: - model.half() - model.prepare_for_inference_(cfg) - self.models = models - # Initialize generator - self.generator = task.build_generator(models, cfg.generation) - - # Initialize transform - from torchvision import transforms - mean = [0.5, 0.5, 0.5] - std = [0.5, 0.5, 0.5] - - self.patch_resize_transform = transforms.Compose([ - lambda image: image.convert('RGB'), - transforms.Resize( - (cfg.task.patch_image_size, cfg.task.patch_image_size), - interpolation=Image.BICUBIC), - transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), - ]) - self.task = task - - def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - import fairseq.utils - if torch.cuda.is_available(): - input = fairseq.utils.move_to_cuda(input, device=self._device) - results, _ = self.eval_caption(self.task, self.generator, self.models, - input) - from modelscope.outputs import OutputKeys - return { - 'image_id': results[0]['image_id'], - OutputKeys.CAPTION: results[0][OutputKeys.CAPTION] - } - - def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - # What should we do here ? 
- return inputs diff --git a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py index 548271de..590fb67b 100644 --- a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py +++ b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py @@ -194,13 +194,6 @@ class SequenceGenerator(nn.Module): bos_token: Optional[int] = None, ): model = EnsembleModel(models) - # incremental_states = torch.jit.annotate( - # List[Dict[str, Dict[str, Optional[Tensor]]]], - # [ - # torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) - # for i in range(model.models_size) - # ], - # ) incremental_states = torch.jit.annotate( List[Tuple[Tuple[torch.Tensor]]], [ @@ -208,8 +201,6 @@ class SequenceGenerator(nn.Module): for i in range(model.models_size) ], ) - # print("incremental_states",incremental_states) - # print("incremental_states[0]",incremental_states[0]) net_input = sample['net_input'] if 'src_tokens' in net_input: @@ -281,7 +272,6 @@ class SequenceGenerator(nn.Module): tokens = (torch.zeros(bsz * beam_size, max_len + 2).to(src_tokens).long().fill_( self.pad)) # +2 for eos and pad - # tokens[:, 0] = self.eos if bos_token is None else bos_token tokens[:, 0] = self.bos attn: Optional[Tensor] = None @@ -335,7 +325,7 @@ class SequenceGenerator(nn.Module): corr.unsqueeze(-1) * beam_size) original_batch_idxs = original_batch_idxs[batch_idxs] model.reorder_incremental_state(incremental_states, - reorder_state) # todo + reorder_state) encoder_outs = model.reorder_encoder_out( encoder_outs, reorder_state) @@ -479,7 +469,6 @@ class SequenceGenerator(nn.Module): batch_mask = torch.ones( bsz, dtype=torch.bool, device=cand_indices.device) batch_mask[finalized_sents] = False - # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it batch_idxs = torch.arange( bsz, device=cand_indices.device).masked_select(batch_mask) @@ -833,7 +822,7 @@ class EnsembleModel(nn.Module): # decode each model if self.has_incremental_states(): - decoder_out = model.decoder.forward( # todo 模型输入不同 + decoder_out = model.decoder.forward( input_ids=tokens, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, @@ -846,7 +835,7 @@ class EnsembleModel(nn.Module): else: if hasattr(model, 'decoder'): # decoder_out = model.decoder.forward(tokens, code_masks=code_mask, encoder_out=encoder_out) - decoder_out = model.decoder.forward( # todo 模型输入不同 + decoder_out = model.decoder.forward( input_ids=tokens, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, @@ -855,32 +844,9 @@ class EnsembleModel(nn.Module): src_pos_embed=src_pos_embed) else: decoder_out = model.forward(tokens) - # print('#### decoder_out ####', decoder_out) - # print('#### decoder_out ####', decoder_out.keys()) - # for k,v in decoder_out.items(): - # print(k) - # if isinstance(v, Tensor): - # print(v.shape) - # elif k == "past_key_values": - # print(len(v)) - # print([v[0][i].shape for i in range(len(v[0]))]) - # else: - # print(len(v)) - # print([v[i].shape for i in range(len(v))]) attn: Optional[Tensor] = None decoder_len = len(decoder_out) - # if decoder_len > 1 and decoder_out[1] is not None: - # if isinstance(decoder_out[1], Tensor): - # attn = decoder_out[1] - # else: - # attn_holder = decoder_out[1]["attn"] - # if isinstance(attn_holder, Tensor): - # attn = attn_holder - # elif attn_holder is not None: - # attn = attn_holder[0] - # if attn is not None: - # attn = attn[:, -1, :] if 'cross_attentions' in decoder_out: attn 
= decoder_out['cross_attentions'][-1].transpose(1, 0) @@ -888,11 +854,6 @@ class EnsembleModel(nn.Module): if attn is not None: attn = attn[:, -1, :] - # decoder_out_tuple = ( - # decoder_out[0][:, -1:, :].div_(temperature), - # None if decoder_len <= 1 else decoder_out[1], - # ) - decoder_out_tuple = ( decoder_out[0][:, -1:, :].div_(temperature), None if decoder_len <= 1 else attn, @@ -993,5 +954,5 @@ class EnsembleModel(nn.Module): if not self.has_incremental_states(): return for i, model in enumerate(self.models): - model.decoder.reorder_incremental_state_scripting( # todo + model.decoder.reorder_incremental_state_scripting( incremental_states[i], new_order) diff --git a/modelscope/models/multi_modal/ofa/utils/constant.py b/modelscope/models/multi_modal/ofa/utils/constant.py new file mode 100644 index 00000000..984da443 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/utils/constant.py @@ -0,0 +1,13 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + +OFA_TASK_KEY_MAPPING = { + Tasks.image_captioning: OutputKeys.CAPTION, + Tasks.summarization: OutputKeys.TEXT, + Tasks.visual_question_answering: OutputKeys.TEXT, + Tasks.visual_grounding: OutputKeys.BOXES, + Tasks.text_classification: (OutputKeys.SCORES, OutputKeys.LABELS), + Tasks.image_classification: OutputKeys.LABELS, + Tasks.visual_entailment: (OutputKeys.SCORES, OutputKeys.LABELS), +} diff --git a/modelscope/models/multi_modal/ofa/utils/utils.py b/modelscope/models/multi_modal/ofa/utils/utils.py new file mode 100644 index 00000000..6d8943a1 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/utils/utils.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Optional + +import torch + + +def expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + r""" + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + return expanded_mask.masked_fill(expanded_mask.bool(), + torch.finfo(dtype).min) diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py new file mode 100644 index 00000000..aaeccaf9 --- /dev/null +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -0,0 +1,259 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import math +from os import path as osp +from typing import Any, Dict + +import json +import torch.cuda +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.base import Model, Tensor +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.ofa.utils.collate import collate_tokens +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.trie import Trie +from .ofa import OFAModel, OFATokenizer +from .ofa.generate import sequence_generator as sg +from .ofa.generate.utils import move_to_device +from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks +from .ofa.utils.utils import expand_mask + +__all__ = ['OfaForAllTasks'] + + +@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) +@MODELS.register_module(Tasks.visual_grounding, module_name=Models.ofa) +@MODELS.register_module( + Tasks.visual_question_answering, module_name=Models.ofa) +@MODELS.register_module(Tasks.visual_entailment, module_name=Models.ofa) +@MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) +@MODELS.register_module(Tasks.summarization, module_name=Models.ofa) +@MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) +class OfaForAllTasks(Model): + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir=model_dir, *args, **kwargs) + model = OFAModel.from_pretrained(model_dir) + self.cfg = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + self.model = model.module if hasattr(model, 'module') else model + self.tokenizer = OFATokenizer.from_pretrained(model_dir) + self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) + self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)]) + self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) + self.batch_size = self.cfg.model.get('batch_size', 1) + self.val_batch_size = self.cfg.model.get('valid_batch_size', + self.batch_size) + self.gen_type = self.cfg.model.get('gen_type', 'generation') + assert self.gen_type in ['generation', 'traverse'], \ + 'model.gen_type must be in ["generation", "traverse"]' + self._device = torch.device('cuda') if torch.cuda.is_available() \ + else torch.device('cpu') + self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id + ]).to(self._device) + self.index2ans = {} + self.ans2label_dict = {} + self.load_ans2label() + # Initialize generator + sg_args = { + 'tokenizer': self.tokenizer, + 'beam_size': 5, + 'max_len_b': 16, + 'min_len': 1, + 'no_repeat_ngram_size': 3, + 'constraint_range': None + } + if hasattr(self.cfg.model, 'beam_search'): + sg_args.update(self.cfg.model.beam_search) + if len(self.ans2label_dict) > 0: + self.constraint_trie = Trie(self.tokenizer.eos_token_id) + self.val_ans_l = [] + self.val_masks_l = [] + self.build_trie() + sg_args['constraint_trie'] = self.constraint_trie + self.model.to(self._device) + self.generator = sg.SequenceGenerator(**sg_args) + inference_d = { + 'generation': self._text_gen_inference, + 'traverse': self._traverse_inference, + } + self.task_inference_mapping = { + Tasks.image_captioning: self._text_gen_inference, + Tasks.summarization: self._text_gen_inference, + Tasks.visual_grounding: self._visual_grounding_inference, + Tasks.visual_entailment: inference_d[self.gen_type], + Tasks.visual_question_answering: inference_d[self.gen_type], + Tasks.text_classification: inference_d[self.gen_type], + Tasks.image_classification:
inference_d[self.gen_type], + } + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + ret = self.task_inference_mapping[self.cfg.task](input) + ret['samples'] = input['samples'] + for key in [ + OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, + OutputKeys.LABELS, OutputKeys.SCORES + ]: + if key in ret and len(ret[key]) == 1: + ret[key] = ret[key][0] + if key not in ret: + ret[key] = None + return ret + + def postprocess(self, input: Dict[str, Tensor], + **kwargs) -> Dict[str, Tensor]: + return input + + def _text_gen_inference(self, input): + input = move_to_device(input, self._device) + gen_output = self.generator.generate([self.model], input) + gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] + result = self.tokenizer.batch_decode(gen, skip_special_tokens=True) + # text generation tasks have no score + ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result} + if self.cfg.task.endswith('classification'): + ret[OutputKeys.SCORES] = [1.0] * len(result) + return ret + + def _visual_grounding_inference(self, input): + input = move_to_device(input, self._device) + gen_output = self.generator.generate([self.model], input) + tokens = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] + region_coord_l = list() + for i in range(len(tokens)): + region_coord_l.append(tokens[i][:-1] + - len(self.tokenizer.get_vocab().items()) + + self.cfg.num_bins) + region_tensor = torch.stack(region_coord_l, dim=0) + region_tensor = region_tensor / ( + self.cfg.num_bins - 1) * self.cfg.model.get('max_image_size', 512) + region_tensor[:, ::2] /= input['w_resize_ratios'] + region_tensor[:, 1::2] /= input['h_resize_ratios'] + return { + OutputKeys.BOXES: move_to_device(region_tensor, + torch.device('cpu')), + OutputKeys.SCORES: [1.0] * region_tensor.shape[0] + } + + def _traverse_inference(self, input): + input = move_to_device(input, self._device) + encoder_input = dict() + for key in input['net_input'].keys(): + encoder_input[key] = input['net_input'][key] + encoder_out = self.model.encoder(**encoder_input) + valid_result = [] + for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l): + valid_size = len(val_ans) + valid_tgt_items = [ + torch.cat([ + torch.tensor(decoder_prompt[1:]), valid_answer, + self.eos_item + ]) for decoder_prompt in input['decoder_prompts'] + for valid_answer in val_ans + ] + valid_prev_items = [ + torch.cat([torch.tensor(decoder_prompt), valid_answer]) + for decoder_prompt in input['decoder_prompts'] + for valid_answer in val_ans + ] + valid_constraint_mask_items = [ + torch.cat([ + torch.zeros( + len(decoder_prompt) - 1, + valid_constraint_mask.size(1)).bool().to(self._device), + valid_constraint_mask], dim=0) # yapf: disable + for decoder_prompt in input['decoder_prompts'] # yapf: disable + for valid_constraint_mask in val_masks] # yapf: disable + valid_tgt = collate_tokens( + valid_tgt_items, + pad_idx=self.tokenizer.pad_token_id).to(self._device) + valid_prev_output = collate_tokens( + valid_prev_items, + pad_idx=self.tokenizer.pad_token_id).to(self._device) + val_masks = collate_tokens( + valid_constraint_mask_items, + pad_idx=self.tokenizer.pad_token_id).to(self._device) + new_encoder_out = { + 'last_hidden_state': + encoder_out['last_hidden_state'].repeat_interleave( + valid_size, dim=0), + 'padding_mask': + encoder_out['padding_mask'].repeat_interleave( + valid_size, dim=0), + 'position_embedding': + encoder_out['position_embedding'].repeat_interleave( + valid_size, dim=0) + } + encoder_attention_mask = expand_mask( + 
new_encoder_out['padding_mask'], + new_encoder_out['last_hidden_state'].dtype, + valid_prev_output.shape[-1]) + + decoder_out = self.model.decoder( + valid_prev_output, + encoder_hidden_states=new_encoder_out['last_hidden_state'], + encoder_attention_mask=encoder_attention_mask, + src_pos_embed=new_encoder_out['position_embedding']) + + decoder_out[0].masked_fill_(~val_masks, -math.inf) + lprobs = self.model.get_normalized_probs( + decoder_out, log_probs=True) + scores = lprobs.gather( + dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1) + scores = scores.masked_fill( + valid_tgt.eq(self.tokenizer.pad_token_id), 0) + scores = scores.masked_fill((~val_masks).all(2), 0) + scores = scores.sum(1) + scores = scores.view(-1, valid_size) + valid_result.append(scores) + valid_result = torch.cat(valid_result, dim=-1) + predicts = valid_result.argmax(1).tolist() + probs = F.softmax(valid_result, dim=-1) + hyps = [self.index2ans[predict_index] for predict_index in predicts] + scores = [ + float(prob[idx].cpu().detach().numpy()) + for prob, idx in zip(probs, predicts) + ] + return {OutputKeys.LABELS: hyps, OutputKeys.SCORES: scores} + + def build_trie(self): + answer_item_list = [] + + for i, answer in enumerate(self.ans2label_dict.keys()): + answer_item = self.tokenizer( + ' ' + answer, return_tensors='pt', + add_special_tokens=False).input_ids.squeeze(0) + answer_item_list.append(answer_item) + self.index2ans[i] = answer + self.constraint_trie.insert([self.tokenizer.bos_token_id] + + answer_item.tolist() + + [self.tokenizer.eos_token_id]) + + constraint_mask_list = [] + for answer_item in answer_item_list: + constraint_mask = torch.zeros( + (len(answer_item) + 1, + len(self.tokenizer.get_vocab()))).bool() + for i in range(len(answer_item) + 1): + constraint_prefix_token = [self.tokenizer.bos_token_id + ] + answer_item[:i].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + constraint_mask_list.append(constraint_mask) + + for i in range(0, len(answer_item_list), self.val_batch_size): + self.val_ans_l += [answer_item_list[i:i + self.val_batch_size]] + self.val_masks_l += [ + constraint_mask_list[i:i + self.val_batch_size] + ] + self.val_ans_l = move_to_device(self.val_ans_l, self._device) + self.val_masks_l = move_to_device(self.val_masks_l, self._device) + + def load_ans2label(self): + if self.cfg.model.get('answer2label', None): + filename = osp.join(self.model_dir, self.cfg.model.answer2label) + self.ans2label_dict = json.load(open(filename)) diff --git a/modelscope/models/multi_modal/ofa_for_image_captioning_model.py b/modelscope/models/multi_modal/ofa_for_image_captioning_model.py deleted file mode 100644 index 5d646143..00000000 --- a/modelscope/models/multi_modal/ofa_for_image_captioning_model.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Any, Dict - -import torch.cuda - -from modelscope.metainfo import Models -from modelscope.models.base import Model -from modelscope.models.builder import MODELS -from modelscope.outputs import OutputKeys -from modelscope.utils.constant import Tasks -from .ofa import OFAModel, OFATokenizer -from .ofa.generate import sequence_generator as sg -from .ofa.generate.utils import move_to_device - -__all__ = ['OfaForImageCaptioning'] - - -@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) -class OfaForImageCaptioning(Model): - - def __init__(self, model_dir, *args, **kwargs): - super().__init__(model_dir=model_dir, *args, **kwargs) - model = 
OFAModel.from_pretrained(model_dir) - - self.model = model.module if hasattr(model, 'module') else model - self.tokenizer = OFATokenizer.from_pretrained(model_dir) - self.tokenizer.add_tokens([''.format(i) for i in range(8192)]) - self.tokenizer.add_tokens([''.format(i) for i in range(1000)]) - self._device = torch.device('cuda') if torch.cuda.is_available() \ - else torch.device('cpu') - self.model.to(self._device) - # Initialize generator - sg_args = { - 'tokenizer': self.tokenizer, - 'beam_size': 5, - 'max_len_b': 16, - 'min_len': 1, - 'no_repeat_ngram_size': 3, - 'constraint_range': None - } - if hasattr(kwargs, 'beam_search'): - sg_args.update(kwargs['beam_search']) - self.generator = sg.SequenceGenerator(**sg_args) - - def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - input = move_to_device(input, self._device) - gen_output = self.generator.generate([self.model], input) - gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] - result = self.tokenizer.batch_decode(gen, skip_special_tokens=True) - return {'image_id': '42', OutputKeys.CAPTION: result[0]} - - def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - # What should we do here ? - return inputs diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index abfefcca..f8a8f1d1 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: from .ocr_detection_pipeline import OCRDetectionPipeline from .video_category_pipeline import VideoCategoryPipeline from .virtual_tryon_pipeline import VirtualTryonPipeline + from .image_classification_pipeline import ImageClassificationPipeline else: _import_structure = { 'action_recognition_pipeline': ['ActionRecognitionPipeline'], @@ -33,7 +34,7 @@ else: 'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], 'face_recognition_pipeline': ['FaceRecognitionPipeline'], 'image_classification_pipeline': - ['GeneralImageClassificationPipeline'], + ['GeneralImageClassificationPipeline', 'ImageClassificationPipeline'], 'image_cartoon_pipeline': ['ImageCartoonPipeline'], 'image_denoise_pipeline': ['ImageDenoisePipeline'], 'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'], diff --git a/modelscope/pipelines/cv/image_classification_pipeline.py b/modelscope/pipelines/cv/image_classification_pipeline.py index 169187fe..cf48de6b 100644 --- a/modelscope/pipelines/cv/image_classification_pipeline.py +++ b/modelscope/pipelines/cv/image_classification_pipeline.py @@ -1,4 +1,5 @@ -from typing import Any, Dict +# Copyright (c) Alibaba, Inc. and its affiliates. 
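Illustrative usage sketch (not part of the patch): the pipeline wrappers added in this change all follow the same pattern of resolving the model, switching it to eval mode, and falling back to an OfaPreprocessor built from the model directory. With placeholder model IDs and file paths (substitute the OFA checkpoints actually published on the ModelScope hub), the new tasks would be driven roughly like this:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Placeholder model IDs and paths, for illustration only.
img_captioning = pipeline(Tasks.image_captioning, model='damo/ofa_image-caption_coco_large_en')
print(img_captioning({'image': 'data/test/images/image_captioning.png'}))  # -> {'caption': '...'}

summarizer = pipeline(Tasks.summarization, model='damo/ofa_summarization_gigaword_large_en')
print(summarizer({'text': 'A short news article to be summarized ...'}))  # -> {'text': '...'}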
+from typing import Any, Dict, Union import cv2 import numpy as np @@ -7,16 +8,41 @@ import torch from modelscope.metainfo import Pipelines from modelscope.outputs import OutputKeys -from modelscope.pipelines.base import Input -from modelscope.preprocessors import load_image +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor, load_image from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger -from ..base import Pipeline -from ..builder import PIPELINES logger = get_logger() +@PIPELINES.register_module( + Tasks.image_classification, module_name=Pipelines.image_classification) +class ImageClassificationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: [Preprocessor] = None, + **kwargs): + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None and pipe_model: + preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + @PIPELINES.register_module( Tasks.image_classification_imagenet, module_name=Pipelines.general_image_classification) @@ -27,7 +53,7 @@ class GeneralImageClassificationPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` and `preprocessor` to create a kws pipeline for prediction + use `model` and `preprocessor` to create an image classification pipeline for prediction Args: model: model id on modelscope hub.
""" diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index 523a1002..55906e43 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -5,7 +5,9 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline - from .image_captioning_pipeline import ImageCaptionPipeline + from .image_captioning_pipeline import ImageCaptioningPipeline + from .visual_entailment_pipeline import VisualEntailmentPipeline + from .visual_grounding_pipeline import VisualGroundingPipeline from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline from .video_multi_modal_embedding_pipeline import \ @@ -14,7 +16,9 @@ if TYPE_CHECKING: else: _import_structure = { - 'image_captioning_pipeline': ['ImageCaptionPipeline'], + 'image_captioning_pipeline': ['ImageCaptioningPipeline'], + 'visual_entailment_pipeline': ['VisualEntailmentPipeline'], + 'visual_grounding_pipeline': ['VisualGroundingPipeline'], 'multi_modal_embedding_pipeline': ['MultiModalEmbeddingPipeline'], 'text_to_image_synthesis_pipeline': ['TextToImageSynthesisPipeline'], 'visual_question_answering_pipeline': diff --git a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py index 90b1dec0..4d491ceb 100644 --- a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py +++ b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py @@ -1,9 +1,10 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict, Optional, Union from modelscope.metainfo import Pipelines from modelscope.pipelines.base import Model, Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor +from modelscope.preprocessors import OfaPreprocessor, Preprocessor from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -12,28 +13,29 @@ logger = get_logger() @PIPELINES.register_module( Tasks.image_captioning, module_name=Pipelines.image_captioning) -class ImageCaptionPipeline(Pipeline): +class ImageCaptioningPipeline(Pipeline): def __init__(self, model: Union[Model, str], preprocessor: Optional[Preprocessor] = None, **kwargs): """ - use `model` and `preprocessor` to create a kws pipeline for prediction + use `model` and `preprocessor` to create a image captioning pipeline for prediction Args: model: model id on modelscope hub. 
""" super().__init__(model=model) assert isinstance(model, str) or isinstance(model, Model), \ - 'model must be a single str or OfaForImageCaptioning' + 'model must be a single str or OfaForAllTasks' if isinstance(model, str): pipe_model = Model.from_pretrained(model) elif isinstance(model, Model): pipe_model = model else: raise NotImplementedError + pipe_model.model.eval() if preprocessor is None and pipe_model: - preprocessor = OfaImageCaptionPreprocessor(model_dir=model) + preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/pipelines/multi_modal/visual_entailment_pipeline.py b/modelscope/pipelines/multi_modal/visual_entailment_pipeline.py new file mode 100644 index 00000000..e1bd3929 --- /dev/null +++ b/modelscope/pipelines/multi_modal/visual_entailment_pipeline.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.visual_entailment, module_name=Pipelines.visual_entailment) +class VisualEntailmentPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: [Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a visual entailment pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None and pipe_model: + preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/visual_grounding_pipeline.py b/modelscope/pipelines/multi_modal/visual_grounding_pipeline.py new file mode 100644 index 00000000..a603d4fd --- /dev/null +++ b/modelscope/pipelines/multi_modal/visual_grounding_pipeline.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.visual_grounding, module_name=Pipelines.visual_grounding) +class VisualGroundingPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: [Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a visual grounding pipeline for prediction + Args: + model: model id on modelscope hub. 
+ """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None and pipe_model: + preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py index 0b1fedff..47727a29 100644 --- a/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py +++ b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict, Optional, Union import torch @@ -30,15 +31,18 @@ class VisualQuestionAnsweringPipeline(Pipeline): model (MPlugForVisualQuestionAnswering): a model instance preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance """ - model = model if isinstance( - model, - MPlugForVisualQuestionAnswering) else Model.from_pretrained(model) + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + self.tokenizer = None if preprocessor is None: preprocessor = MPlugVisualQuestionAnsweringPreprocessor( model.model_dir) - model.eval() + if isinstance(model, MPlugForVisualQuestionAnswering): + model.eval() + self.tokenizer = model.tokenizer + else: + model.model.eval() super().__init__(model=model, preprocessor=preprocessor, **kwargs) - self.tokenizer = model.tokenizer def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: @@ -55,6 +59,8 @@ class VisualQuestionAnsweringPipeline(Pipeline): Returns: Dict[str, str]: the prediction results """ + if self.tokenizer is None: + return inputs replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 561ced1a..e6a35efc 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: from .translation_pipeline import TranslationPipeline from .word_segmentation_pipeline import WordSegmentationPipeline from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline + from .summarization_pipeline import SummarizationPipeline + from .text_classification_pipeline import TextClassificationPipeline from .text_error_correction_pipeline import TextErrorCorrectionPipeline else: @@ -38,6 +40,8 @@ else: 'named_entity_recognition_pipeline': ['NamedEntityRecognitionPipeline'], 'translation_pipeline': ['TranslationPipeline'], + 'summarization_pipeline': ['SummarizationPipeline'], + 'text_classification_pipeline': ['TextClassificationPipeline'], 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'] } diff --git a/modelscope/pipelines/nlp/summarization_pipeline.py b/modelscope/pipelines/nlp/summarization_pipeline.py new file mode 100644 index 00000000..148acc06 --- /dev/null +++ b/modelscope/pipelines/nlp/summarization_pipeline.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.summarization, module_name=Pipelines.text_generation) +class SummarizationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: [Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a summarization pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None and pipe_model: + preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/nlp/text_classification_pipeline.py b/modelscope/pipelines/nlp/text_classification_pipeline.py new file mode 100644 index 00000000..f873d6d7 --- /dev/null +++ b/modelscope/pipelines/nlp/text_classification_pipeline.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.text_classification) +class TextClassificationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: [Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a text classification pipeline for prediction + Args: + model: model id on modelscope hub.
+ """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None and pipe_model: + preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 38fe3b9a..9d991146 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: ImageInstanceSegmentationPreprocessor, ImageDenoisePreprocessor) from .kws import WavToLists - from .multi_modal import (OfaImageCaptionPreprocessor, + from .multi_modal import (OfaPreprocessor, MPlugVisualQuestionAnsweringPreprocessor) from .nlp import (Tokenize, SequenceClassificationPreprocessor, TextGenerationPreprocessor, @@ -41,10 +41,8 @@ else: 'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor' ], 'kws': ['WavToLists'], - 'multi_modal': [ - 'OfaImageCaptionPreprocessor', - 'MPlugVisualQuestionAnsweringPreprocessor' - ], + 'multi_modal': + ['OfaPreprocessor', 'MPlugVisualQuestionAnsweringPreprocessor'], 'nlp': [ 'Tokenize', 'SequenceClassificationPreprocessor', 'TextGenerationPreprocessor', 'TokenClassificationPreprocessor', diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 56bcfcd1..055c4efb 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -4,26 +4,25 @@ from typing import Any, Dict, Union import torch from PIL import Image -from torchvision import transforms from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Preprocessors -from modelscope.models.multi_modal.ofa import OFATokenizer -from modelscope.utils.constant import Fields -from modelscope.utils.type_assert import type_assert +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModelFile, Tasks from .base import Preprocessor from .builder import PREPROCESSORS -from .image import load_image +from .ofa import * # noqa +from .ofa.utils.collate import collate_fn __all__ = [ - 'OfaImageCaptionPreprocessor', + 'OfaPreprocessor', 'MPlugVisualQuestionAnsweringPreprocessor', ] @PREPROCESSORS.register_module( Fields.multi_modal, module_name=Preprocessors.ofa_image_caption) -class OfaImageCaptionPreprocessor(Preprocessor): +class OfaPreprocessor(Preprocessor): def __init__(self, model_dir: str, *args, **kwargs): """preprocess the data via the vocab.txt from the `model_dir` path @@ -32,41 +31,28 @@ class OfaImageCaptionPreprocessor(Preprocessor): model_dir (str): model path """ super().__init__(*args, **kwargs) + preprocess_mapping = { + Tasks.image_captioning: OfaImageCaptioningPreprocessor, + Tasks.visual_grounding: OfaVisualGroundingPreprocessor, + Tasks.visual_question_answering: + OfaVisualQuestionAnsweringPreprocessor, + Tasks.visual_entailment: OfaVisualEntailmentPreprocessor, + Tasks.image_classification: OfaImageClassificationPreprocessor, + Tasks.text_classification: OfaTextClassificationPreprocessor, + Tasks.summarization: OfaSummarizationPreprocessor + } model_dir = model_dir if 
osp.exists(model_dir) else snapshot_download( model_dir) - self.tokenizer = OFATokenizer.from_pretrained(model_dir) - self.tokenizer.add_tokens([''.format(i) for i in range(8192)]) - self.tokenizer.add_tokens([''.format(i) for i in range(1000)]) + cfg = Config.from_file(osp.join(model_dir, ModelFile.CONFIGURATION)) + self.preprocess = preprocess_mapping[cfg.task](cfg, model_dir) + self.tokenizer = self.preprocess.tokenizer - # Initialize transform - mean = [0.5, 0.5, 0.5] - std = [0.5, 0.5, 0.5] - patch_image_size = 480 - self.patch_resize_transform = transforms.Compose([ - lambda image: image.convert('RGB'), - transforms.Resize((patch_image_size, patch_image_size), - interpolation=Image.BICUBIC), - transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), - ]) - - @type_assert(object, (str, tuple, Image.Image)) - def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]: - if isinstance(data, Image.Image): - patch_image = self.patch_resize_transform(data).unsqueeze(0) - else: - patch_image = self.patch_resize_transform( - load_image(data)).unsqueeze(0) - text = ' what does the image describe?' - inputs = self.tokenizer([text], max_length=1024, - return_tensors='pt')['input_ids'] - sample = dict() - sample['net_input'] = { - 'input_ids': inputs, - 'patch_images': patch_image, - 'patch_masks': torch.tensor([True]) - } - return sample + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self.preprocess(data) + sample['sample'] = data + return collate_fn([sample], + pad_idx=self.tokenizer.pad_token_id, + eos_idx=self.tokenizer.eos_token_id) @PREPROCESSORS.register_module( diff --git a/modelscope/preprocessors/ofa/__init__.py b/modelscope/preprocessors/ofa/__init__.py new file mode 100644 index 00000000..44954668 --- /dev/null +++ b/modelscope/preprocessors/ofa/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .image_captioning import OfaImageCaptioningPreprocessor +from .image_classification import OfaImageClassificationPreprocessor +from .summarization import OfaSummarizationPreprocessor +from .text_classification import OfaTextClassificationPreprocessor +from .visual_entailment import OfaVisualEntailmentPreprocessor +from .visual_grounding import OfaVisualGroundingPreprocessor +from .visual_question_answering import OfaVisualQuestionAnsweringPreprocessor diff --git a/modelscope/preprocessors/ofa/base.py b/modelscope/preprocessors/ofa/base.py new file mode 100644 index 00000000..8f53dbf7 --- /dev/null +++ b/modelscope/preprocessors/ofa/base.py @@ -0,0 +1,117 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
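Note (illustrative, not part of the patch): OfaForAllTasks.build_trie above and OfaBasePreprocessor below both build answer-constraint masks on top of modelscope/utils/trie.py, which this patch adds but whose body falls outside the hunks shown here. A minimal sketch of a prefix trie matching the interface used in this code (constructor taking the EOS token id, insert over token-id sequences, get_next_layer returning the admissible next ids); this is an assumption for illustration, not the actual file contents:

class Trie:
    """Illustrative prefix trie over token-id sequences (assumed interface only)."""

    def __init__(self, eos_token_id):
        self.root = {}
        self.eos_token_id = eos_token_id

    def insert(self, token_ids):
        # Extend one branch of nested dicts per token id.
        node = self.root
        for token in token_ids:
            node = node.setdefault(token, {})

    def get_next_layer(self, prefix_token_ids):
        # Return the token ids allowed to follow the given prefix.
        node = self.root
        for token in prefix_token_ids:
            if token not in node:
                # Unknown prefix: assume only EOS may follow.
                return [self.eos_token_id]
            node = node[token]
        return list(node.keys())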
+import re +from os import path as osp + +import json +import numpy as np +import torch + +from modelscope.models.multi_modal.ofa import OFATokenizer +from modelscope.utils.trie import Trie +from .utils.random_help import set_torch_seed + + +class OfaBasePreprocessor: + + def __init__(self, cfg, model_dir): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path + """ + self.cfg = cfg + tokenizer = OFATokenizer.from_pretrained(model_dir) + tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) + tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)]) + self.tokenizer = tokenizer + self.bos_item = torch.LongTensor([tokenizer.bos_token_id]) + self.pad_item = torch.LongTensor([tokenizer.pad_token_id]) + self.eos_item = torch.LongTensor([tokenizer.eos_token_id]) + self.tgt_dict = self.src_dict = { + value: key + for key, value in tokenizer.get_vocab().items() + } + self.max_src_length = cfg.model.get('max_src_length', 256) + self.max_image_size = cfg.model.get('max_image_size', 512) + self.language = self.cfg.model.get('language', 'en') + self.prompt_type = self.cfg.model.get('prompt_type', 'none') + seed = self.cfg.model.get('seed', 7) + np.random.seed(seed) + set_torch_seed(seed) + imagenet_default_mean_and_std = self.cfg.model.get( + 'imagenet_default_mean_and_std', False) + if imagenet_default_mean_and_std: + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + else: + self.mean = [0.5, 0.5, 0.5] + self.std = [0.5, 0.5, 0.5] + self.patch_image_size = self.cfg.model.get('patch_image_size', 480) + self.constraint_trie = None + self.index2ans = {} + if self.cfg.model.get('answer2label', False): + ans2label_file = osp.join(model_dir, self.cfg.model.answer2label) + ans2label_dict = json.load(open(ans2label_file, 'r')) + self.constraint_trie = Trie(tokenizer.eos_token_id) + for i, answer in enumerate(ans2label_dict.keys()): + answer_item = tokenizer( + ' ' + answer, + return_tensors='pt', + add_special_tokens=False).input_ids.squeeze(0) + self.constraint_trie.insert([tokenizer.bos_token_id] + + answer_item.tolist() + + [tokenizer.eos_token_id]) + + def get_inputs(self, text, add_bos=True, add_eos=True): + inputs = self.tokenizer( + text, + max_length=self.max_src_length, + add_special_tokens=False, + return_tensors='pt')['input_ids'].squeeze(0) + if add_bos: + inputs = torch.cat([self.bos_item, inputs]) + if add_eos: + inputs = torch.cat([inputs, self.eos_item]) + return inputs + + @staticmethod + def pre_caption(caption, max_words=None): + caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ')\ + .replace('/', ' ').replace('<person>', 'person') + + caption = re.sub( + r'\s{2,}', + ' ', + caption, + ) + caption = caption.rstrip('\n') + caption = caption.strip(' ') + + # truncate caption + caption_words = caption.split(' ') + if max_words is not None and len(caption_words) > max_words: + caption = ' '.join(caption_words[:max_words]) + + return caption + + @staticmethod + def pre_question(question, max_ques_words): + question = question.lower().lstrip(',.!?*#:;~').replace('-', + ' ').replace( + '/', ' ') + + question = re.sub( + r'\s{2,}', + ' ', + question, + ) + question = question.rstrip('\n') + question = question.strip(' ') + + # truncate question + question_words = question.split(' ') + if len(question_words) > max_ques_words: + question = ' '.join(question_words[:max_ques_words]) + + return question diff --git
a/modelscope/preprocessors/ofa/image_captioning.py b/modelscope/preprocessors/ofa/image_captioning.py new file mode 100644 index 00000000..264c8e04 --- /dev/null +++ b/modelscope/preprocessors/ofa/image_captioning.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.preprocessors.image import load_image +from .base import OfaBasePreprocessor + + +class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): + + def __init__(self, cfg, model_dir): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path + """ + super(OfaImageCaptioningPreprocessor, self).__init__(cfg, model_dir) + # Initialize transform + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize((self.patch_image_size, self.patch_image_size), + interpolation=Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = data['image'] if isinstance( + data['image'], Image.Image) else load_image(data['image']) + patch_image = self.patch_resize_transform(image) + prompt = self.cfg.model.get('prompt', ' what does the image describe?') + inputs = self.get_inputs(prompt) + sample = { + 'source': inputs, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]) + } + return sample diff --git a/modelscope/preprocessors/ofa/image_classification.py b/modelscope/preprocessors/ofa/image_classification.py new file mode 100644 index 00000000..30289613 --- /dev/null +++ b/modelscope/preprocessors/ofa/image_classification.py @@ -0,0 +1,43 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.preprocessors.image import load_image +from .base import OfaBasePreprocessor + + +class OfaImageClassificationPreprocessor(OfaBasePreprocessor): + + def __init__(self, cfg, model_dir): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path + """ + super(OfaImageClassificationPreprocessor, + self).__init__(cfg, model_dir) + # Initialize transform + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize((self.patch_image_size, self.patch_image_size), + interpolation=Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = data['image'] if isinstance( + data['image'], Image.Image) else load_image(data['image']) + patch_image = self.patch_resize_transform(image) + prompt = self.cfg.model.get('prompt', ' what does the image describe?') + inputs = self.get_inputs(prompt) + sample = { + 'source': inputs, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]) + } + return sample diff --git a/modelscope/preprocessors/ofa/summarization.py b/modelscope/preprocessors/ofa/summarization.py new file mode 100644 index 00000000..fd5113cd --- /dev/null +++ b/modelscope/preprocessors/ofa/summarization.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import Any, Dict + +from .base import OfaBasePreprocessor + + +class OfaSummarizationPreprocessor(OfaBasePreprocessor): + + def __init__(self, cfg, model_dir): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path + """ + super(OfaSummarizationPreprocessor, self).__init__(cfg, model_dir) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + source = super().pre_caption( + data['text'], max_words=self.max_src_length) + source = source.strip()[:self.max_src_length] + source = source.replace('[unk]', 'unk').replace('<unk>', 'unk') + prompt = self.cfg.model.get( + 'prompt', ' " {} " Summarize the article with a title: ') + text = prompt.format(source) + inputs = self.get_inputs(text) + if self.prompt_type == 'none': + decoder_prompt = self.bos_item + elif self.prompt_type == 'prev_output': + decoder_prompt = inputs[:-1] + else: + raise NotImplementedError + sample = { + 'source': inputs, + 'decoder_prompt': decoder_prompt, + } + return sample diff --git a/modelscope/preprocessors/ofa/text_classification.py b/modelscope/preprocessors/ofa/text_classification.py new file mode 100644 index 00000000..1a3f84fd --- /dev/null +++ b/modelscope/preprocessors/ofa/text_classification.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +from .base import OfaBasePreprocessor + + +class OfaTextClassificationPreprocessor(OfaBasePreprocessor): + + def __init__(self, cfg, model_dir): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path + """ + super(OfaTextClassificationPreprocessor, self).__init__(cfg, model_dir) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + text1 = ' '.join( + data['text'].lower().strip().split()[:self.max_src_length]) + text2 = ' '.join( + data['text2'].lower().strip().split()[:self.max_src_length]) + prompt = ' can text1 " {} " imply text2 " {} "?'
+        text = prompt.format(text1, text2)
+        inputs = self.get_inputs(text)
+        if self.prompt_type == 'none':
+            decoder_prompt = self.bos_item
+        elif self.prompt_type == 'src':
+            decoder_prompt = inputs
+        elif self.prompt_type == 'prev_output':
+            decoder_prompt = inputs[:-1]
+        else:
+            raise NotImplementedError
+        sample = {
+            'source': inputs,
+            'decoder_prompt': decoder_prompt,
+        }
+        return sample
diff --git a/modelscope/preprocessors/ofa/utils/__init__.py b/modelscope/preprocessors/ofa/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/preprocessors/ofa/utils/collate.py b/modelscope/preprocessors/ofa/utils/collate.py
new file mode 100644
index 00000000..a473335b
--- /dev/null
+++ b/modelscope/preprocessors/ofa/utils/collate.py
@@ -0,0 +1,109 @@
+import numpy as np
+import torch
+
+
+def collate_fn(samples, pad_idx, eos_idx):
+    if len(samples) == 0:
+        return {}
+
+    def merge(key):
+        return collate_tokens([s[key] for s in samples],
+                              pad_idx,
+                              eos_idx=eos_idx)
+
+    src_tokens = merge('source')
+
+    batch = {
+        'nsentences': len(samples),
+        'net_input': {
+            'input_ids': src_tokens,
+        },
+    }
+    if samples[0].get('id', None) is not None:
+        batch['id'] = np.array([s['id'] for s in samples])
+    if samples[0].get('target', None) is not None:
+        batch['target'] = merge('target')
+        tgt_lengths = torch.LongTensor(
+            [s['target'].ne(pad_idx).long().sum() for s in samples])
+        ntokens = tgt_lengths.sum().item()
+        batch['ntokens'] = ntokens
+    if samples[0].get('prev_output_tokens', None) is not None:
+        batch['net_input']['decoder_input_ids'] = merge('prev_output_tokens')
+    if samples[0].get('patch_image', None) is not None:
+        batch['net_input']['patch_images'] = torch.stack(
+            [sample['patch_image'] for sample in samples], dim=0)
+    if samples[0].get('patch_mask', None) is not None:
+        batch['net_input']['patch_masks'] = torch.cat(
+            [sample['patch_mask'] for sample in samples])
+    # image generation
+    if samples[0].get('code_mask', None) is not None:
+        batch['net_input']['code_masks'] = torch.cat(
+            [sample['code_mask'] for sample in samples])
+    if samples[0].get('code_image', None) is not None:
+        batch['code_images'] = torch.cat(
+            [sample['code_image'] for sample in samples])
+    # For classification tasks (i.e., VQA, SNLI-VE, GLUE)
+    if samples[0].get('conf', None) is not None:
+        batch['conf'] = torch.cat([s['conf'] for s in samples], dim=0)
+    if samples[0].get('ref_dict', None) is not None:
+        batch['ref_dict'] = np.array([s['ref_dict'] for s in samples])
+    if samples[0].get('constraint_mask', None) is not None:
+        batch['constraint_masks'] = merge('constraint_mask')
+    if samples[0].get('decoder_prompt', None) is not None:
+        batch['decoder_prompts'] = np.array(
+            [s['decoder_prompt'].tolist() for s in samples])
+    # For detection and visual grounding
+    if samples[0].get('w_resize_ratio', None) is not None:
+        batch['w_resize_ratios'] = torch.stack(
+            [s['w_resize_ratio'] for s in samples], dim=0)
+    if samples[0].get('h_resize_ratio', None) is not None:
+        batch['h_resize_ratios'] = torch.stack(
+            [s['h_resize_ratio'] for s in samples], dim=0)
+    if samples[0].get('region_coord', None) is not None:
+        batch['region_coords'] = torch.stack(
+            [s['region_coord'] for s in samples], dim=0)
+    if samples[0].get('sample', None) is not None:
+        batch['samples'] = [s['sample'] for s in samples]
+    return batch
+
+
+def collate_tokens(
+    values,
+    pad_idx,
+    eos_idx=None,
+    left_pad=False,
+    move_eos_to_beginning=False,
+    pad_to_length=None,
+    pad_to_multiple=1,
+    pad_to_bsz=None,
+):
"""Convert a list of 1d tensors into a padded 2d tensor.""" + size = max(v.size(0) for v in values) + size = size if pad_to_length is None else max(size, pad_to_length) + if pad_to_multiple != 1 and size % pad_to_multiple != 0: + size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if move_eos_to_beginning: + if eos_idx is None: + # if no eos_idx is specified, then use the last token in src + dst[0] = src[-1] + else: + dst[0] = eos_idx + dst[1:] = src[:-1] + else: + dst.copy_(src) + + if values[0].dim() == 1: + res = values[0].new(len(values), size).fill_(pad_idx) + elif values[0].dim() == 2: + assert move_eos_to_beginning is False + res = values[0].new(len(values), size, + values[0].size(1)).fill_(pad_idx) + else: + raise NotImplementedError + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) + return res diff --git a/modelscope/preprocessors/ofa/utils/random_help.py b/modelscope/preprocessors/ofa/utils/random_help.py new file mode 100644 index 00000000..77f4df3f --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/random_help.py @@ -0,0 +1,42 @@ +import torch + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +def get_rng_state(): + state = {'torch_rng_state': torch.get_rng_state()} + if xm is not None: + state['xla_rng_state'] = xm.get_rng_state() + if torch.cuda.is_available(): + state['cuda_rng_state'] = torch.cuda.get_rng_state() + return state + + +def set_rng_state(state): + torch.set_rng_state(state['torch_rng_state']) + if xm is not None: + xm.set_rng_state(state['xla_rng_state']) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state['cuda_rng_state']) + + +class set_torch_seed(object): + + def __init__(self, seed): + assert isinstance(seed, int) + self.rng_state = get_rng_state() + + torch.manual_seed(seed) + if xm is not None: + xm.set_rng_state(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + def __enter__(self): + return self + + def __exit__(self, *exc): + set_rng_state(self.rng_state) diff --git a/modelscope/preprocessors/ofa/utils/transforms.py b/modelscope/preprocessors/ofa/utils/transforms.py new file mode 100644 index 00000000..3fd312c6 --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/transforms.py @@ -0,0 +1,557 @@ +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# found in the LICENSE file in the root directory. + +import random + +import numpy as np +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +from PIL import Image + + +def crop(image, target, region, delete=True): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? 
+ target['size'] = torch.tensor([h, w]) + + fields = ['labels', 'area'] + + if 'boxes' in target: + boxes = target['boxes'] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target['boxes'] = cropped_boxes.reshape(-1, 4) + target['area'] = area + fields.append('boxes') + + if 'polygons' in target: + polygons = target['polygons'] + num_polygons = polygons.shape[0] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + start_coord = torch.cat([ + torch.tensor([j, i], dtype=torch.float32) + for _ in range(polygons.shape[1] // 2)], dim=0) # yapf: disable# + cropped_boxes = polygons - start_coord + cropped_boxes = torch.min( + cropped_boxes.reshape(num_polygons, -1, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + target['polygons'] = cropped_boxes.reshape(num_polygons, -1) + fields.append('polygons') + + if 'masks' in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append('masks') + + # remove elements for which the boxes or masks that have zero area + if delete and ('boxes' in target or 'masks' in target): + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if 'boxes' in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all( + cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep.tolist()] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + w, h = image.size + target = target.copy() + if 'boxes' in target: + boxes = target['boxes'] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( + [-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) + target['boxes'] = boxes + + if 'polygons' in target: + polygons = target['polygons'] + num_polygons = polygons.shape[0] + polygons = polygons.reshape(num_polygons, -1, 2) * torch.as_tensor( + [-1, 1]) + torch.as_tensor([w, 0]) + target['polygons'] = polygons + + if 'masks' in target: + target['masks'] = target['masks'].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + + if (w <= h and w == size) or (h <= w and h == size): + if max_size is not None: + max_size = int(max_size) + h = min(h, max_size) + w = min(w, max_size) + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + if max_size is not None: + max_size = int(max_size) + oh = min(oh, max_size) + ow = min(ow, max_size) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size, interpolation=Image.BICUBIC) + + if target is None: + return rescaled_image + + ratios = tuple( + float(s) / float(s_orig) + for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if 'boxes' in target: 
+ boxes = target['boxes'] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target['boxes'] = scaled_boxes + + if 'polygons' in target: + polygons = target['polygons'] + scaled_ratio = torch.cat([ + torch.tensor([ratio_width, ratio_height]) + for _ in range(polygons.shape[1] // 2)], dim=0) # yapf: disable + scaled_polygons = polygons * scaled_ratio + target['polygons'] = scaled_polygons + + if 'area' in target: + area = target['area'] + scaled_area = area * (ratio_width * ratio_height) + target['area'] = scaled_area + + h, w = size + target['size'] = torch.tensor([h, w]) + + if 'masks' in target: + assert False + + return rescaled_image, target + + +class CenterCrop(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, target, + (crop_top, crop_left, crop_height, crop_width)) + + +class ObjectCenterCrop(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + + x0 = float(target['boxes'][0][0]) + y0 = float(target['boxes'][0][1]) + x1 = float(target['boxes'][0][2]) + y1 = float(target['boxes'][0][3]) + + center_x = (x0 + x1) / 2 + center_y = (y0 + y1) / 2 + crop_left = max( + center_x - crop_width / 2 + + min(image_width - center_x - crop_width / 2, 0), 0) + crop_top = max( + center_y - crop_height / 2 + + min(image_height - center_y - crop_height / 2, 0), 0) + + return crop( + img, + target, (crop_top, crop_left, crop_height, crop_width), + delete=False) + + +class RandomHorizontalFlip(object): + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize(object): + + def __init__(self, sizes, max_size=None, equal=False): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + self.equal = equal + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + if self.equal: + return resize(img, target, size, size) + else: + return resize(img, target, size, self.max_size) + + +class ToTensor(object): + + def __call__(self, img, target): + return F.to_tensor(img), target + + +class Normalize(object): + + def __init__(self, mean, std, max_image_size=512): + self.mean = mean + self.std = std + self.max_image_size = max_image_size + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + # h, w = image.shape[-2:] + h, w = target['size'][0], target['size'][1] + if 'boxes' in target: + boxes = target['boxes'] + boxes = boxes / self.max_image_size + target['boxes'] = boxes + if 'polygons' in target: + polygons = target['polygons'] + scale = torch.cat([ + torch.tensor([w, h], dtype=torch.float32) + for _ in range(polygons.shape[1] // 2)], dim=0) # yapf: disable + polygons = polygons / scale + target['polygons'] = polygons + return image, target + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = 
self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class LargeScaleJitter(object): + """ + implementation of large scale jitter from copy_paste + """ + + def __init__(self, output_size=512, aug_scale_min=0.3, aug_scale_max=2.0): + self.desired_size = torch.tensor([output_size]) + self.aug_scale_min = aug_scale_min + self.aug_scale_max = aug_scale_max + + def rescale_target(self, scaled_size, image_size, target): + # compute rescaled targets + image_scale = scaled_size / image_size + ratio_height, ratio_width = image_scale + + target = target.copy() + target['size'] = scaled_size + + if 'boxes' in target: + boxes = target['boxes'] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target['boxes'] = scaled_boxes + + if 'area' in target: + area = target['area'] + scaled_area = area * (ratio_width * ratio_height) + target['area'] = scaled_area + + if 'masks' in target: + assert False + masks = target['masks'] + # masks = interpolate( + # masks[:, None].float(), scaled_size, mode="nearest")[:, 0] > 0.5 + target['masks'] = masks + return target + + def crop_target(self, region, target): + i, j, h, w = region + fields = ['labels', 'area'] + + target = target.copy() + target['size'] = torch.tensor([h, w]) + + if 'boxes' in target: + boxes = target['boxes'] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min( + cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] + - cropped_boxes[:, 0, :]).prod(dim=1) + target['boxes'] = cropped_boxes.reshape(-1, 4) + target['area'] = area + fields.append('boxes') + + if 'masks' in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append('masks') + + # remove elements for which the boxes or masks that have zero area + if 'boxes' in target or 'masks' in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if 'boxes' in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all( + cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep.tolist()] + return target + + def pad_target(self, padding, target): + target = target.copy() + if 'masks' in target: + target['masks'] = torch.nn.functional.pad( + target['masks'], (0, padding[1], 0, padding[0])) + return target + + def __call__(self, image, target=None): + image_size = image.size + image_size = torch.tensor(image_size[::-1]) + + random_scale = torch.rand(1) * ( + self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min + scaled_size = (random_scale * self.desired_size).round() + + scale = torch.maximum(scaled_size / image_size[0], + scaled_size / image_size[1]) + scaled_size = (image_size * scale).round().int() + + scaled_image = F.resize( + image, scaled_size.tolist(), interpolation=Image.BICUBIC) + + if target is not None: + target = self.rescale_target(scaled_size, image_size, target) + + # randomly crop or pad images + if random_scale >= 1: + # Selects non-zero random offset (x, y) if scaled image is larger than desired_size. 
+ max_offset = scaled_size - self.desired_size + offset = (max_offset * torch.rand(2)).floor().int() + region = (offset[0].item(), offset[1].item(), + self.desired_size[0].item(), self.desired_size[0].item()) + output_image = F.crop(scaled_image, *region) + if target is not None: + target = self.crop_target(region, target) + else: + assert False + padding = self.desired_size - scaled_size + output_image = F.pad(scaled_image, + [0, 0, padding[1].item(), padding[0].item()]) + if target is not None: + target = self.pad_target(padding, target) + + return output_image, target + + +class OriginLargeScaleJitter(object): + """ + implementation of large scale jitter from copy_paste + """ + + def __init__(self, output_size=512, aug_scale_min=0.3, aug_scale_max=2.0): + self.desired_size = torch.tensor(output_size) + self.aug_scale_min = aug_scale_min + self.aug_scale_max = aug_scale_max + + def rescale_target(self, scaled_size, image_size, target): + # compute rescaled targets + image_scale = scaled_size / image_size + ratio_height, ratio_width = image_scale + + target = target.copy() + target['size'] = scaled_size + + if 'boxes' in target: + boxes = target['boxes'] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target['boxes'] = scaled_boxes + + if 'area' in target: + area = target['area'] + scaled_area = area * (ratio_width * ratio_height) + target['area'] = scaled_area + + if 'masks' in target: + assert False + masks = target['masks'] + # masks = interpolate( + # masks[:, None].float(), scaled_size, mode="nearest")[:, 0] > 0.5 + target['masks'] = masks + return target + + def crop_target(self, region, target): + i, j, h, w = region + fields = ['labels', 'area'] + + target = target.copy() + target['size'] = torch.tensor([h, w]) + + if 'boxes' in target: + boxes = target['boxes'] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min( + cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] + - cropped_boxes[:, 0, :]).prod(dim=1) + target['boxes'] = cropped_boxes.reshape(-1, 4) + target['area'] = area + fields.append('boxes') + + if 'masks' in target: + # FIXME should we update the area here if there are no boxes? 
+ target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append('masks') + + # remove elements for which the boxes or masks that have zero area + if 'boxes' in target or 'masks' in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if 'boxes' in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all( + cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep.tolist()] + return target + + def pad_target(self, padding, target): + target = target.copy() + if 'masks' in target: + target['masks'] = torch.nn.functional.pad( + target['masks'], (0, padding[1], 0, padding[0])) + return target + + def __call__(self, image, target=None): + image_size = image.size + image_size = torch.tensor(image_size[::-1]) + + out_desired_size = (self.desired_size * image_size + / max(image_size)).round().int() + + random_scale = torch.rand(1) * ( + self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min + scaled_size = (random_scale * self.desired_size).round() + + scale = torch.minimum(scaled_size / image_size[0], + scaled_size / image_size[1]) + scaled_size = (image_size * scale).round().int() + + scaled_image = F.resize(image, scaled_size.tolist()) + + if target is not None: + target = self.rescale_target(scaled_size, image_size, target) + + # randomly crop or pad images + if random_scale > 1: + # Selects non-zero random offset (x, y) if scaled image is larger than desired_size. + max_offset = scaled_size - out_desired_size + offset = (max_offset * torch.rand(2)).floor().int() + region = (offset[0].item(), offset[1].item(), + out_desired_size[0].item(), out_desired_size[1].item()) + output_image = F.crop(scaled_image, *region) + if target is not None: + target = self.crop_target(region, target) + else: + padding = out_desired_size - scaled_size + output_image = F.pad(scaled_image, + [0, 0, padding[1].item(), padding[0].item()]) + if target is not None: + target = self.pad_target(padding, target) + + return output_image, target + + +class RandomDistortion(object): + """ + Distort image w.r.t hue, saturation and exposure. + """ + + def __init__(self, + brightness=0, + contrast=0, + saturation=0, + hue=0, + prob=0.5): + self.prob = prob + self.tfm = T.ColorJitter(brightness, contrast, saturation, hue) + + def __call__(self, img, target=None): + if np.random.random() < self.prob: + return self.tfm(img), target + else: + return img, target diff --git a/modelscope/preprocessors/ofa/utils/vision_helper.py b/modelscope/preprocessors/ofa/utils/vision_helper.py new file mode 100644 index 00000000..518b110a --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/vision_helper.py @@ -0,0 +1,357 @@ +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# found in the LICENSE file in the root directory. 
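+#
+# OpenCV-based re-implementations of PIL-style photometric and geometric ops
+# (autocontrast, equalize, rotate, solarize, contrast, brightness, sharpness,
+# shear, translate, posterize, cutout) used by RandomAugment below.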
+ +import cv2 +import numpy as np + + +def identity_func(img): + return img + + +def autocontrast_func(img, cutoff=0): + ''' + same output as PIL.ImageOps.autocontrast + ''' + n_bins = 256 + + def tune_channel(ch): + n = ch.size + cut = cutoff * n // 100 + if cut == 0: + high, low = ch.max(), ch.min() + else: + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + low = np.argwhere(np.cumsum(hist) > cut) + low = 0 if low.shape[0] == 0 else low[0] + high = np.argwhere(np.cumsum(hist[::-1]) > cut) + high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] + if high <= low: + table = np.arange(n_bins) + else: + scale = (n_bins - 1) / (high - low) + offset = -low * scale + table = np.arange(n_bins) * scale + offset + table[table < 0] = 0 + table[table > n_bins - 1] = n_bins - 1 + table = table.clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def equalize_func(img): + ''' + same output as PIL.ImageOps.equalize + PIL's implementation is different from cv2.equalize + ''' + n_bins = 256 + + def tune_channel(ch): + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + non_zero_hist = hist[hist != 0].reshape(-1) + step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) + if step == 0: + return ch + n = np.empty_like(hist) + n[0] = step // 2 + n[1:] = hist[:-1] + table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def rotate_func(img, degree, fill=(0, 0, 0)): + ''' + like PIL, rotate by degree, not radians + ''' + H, W = img.shape[0], img.shape[1] + center = W / 2, H / 2 + M = cv2.getRotationMatrix2D(center, degree, 1) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill) + return out + + +def solarize_func(img, thresh=128): + ''' + same output as PIL.ImageOps.posterize + ''' + table = np.array([el if el < thresh else 255 - el for el in range(256)]) + table = table.clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def color_func(img, factor): + # same output as PIL.ImageEnhance.Color + M = ( + np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], + [-0.299, -0.299, 0.701]]) * factor + + np.float32([[0.114], [0.587], [0.299]])) + out = np.matmul(img, M).clip(0, 255).astype(np.uint8) + return out + + +def contrast_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) + table = np.array([(el - mean) * factor + mean + for el in range(256)]).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def brightness_func(img, factor): + ''' + same output as PIL.ImageEnhance.Contrast + ''' + table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype( + np.uint8) + out = table[img] + return out + + +def sharpness_func(img, factor): + ''' + The differences the this result and PIL are all on the 4 boundaries, the center + areas are same + ''' + kernel = np.ones((3, 3), dtype=np.float32) + kernel[1][1] = 5 + kernel /= 13 + degenerate = cv2.filter2D(img, -1, kernel) + if factor == 0.0: + out = degenerate + elif factor == 1.0: + out = img + else: + out = img.astype(np.float32) + degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] + out[1:-1, 1:-1, :] = degenerate + factor * ( + out[1:-1, 1:-1, :] - degenerate) + out = out.astype(np.uint8) + return out + + +def shear_x_func(img, factor, fill=(0, 0, 0)): + H, W = 
img.shape[0], img.shape[1] + M = np.float32([[1, factor, 0], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, + flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def translate_x_func(img, offset, fill=(0, 0, 0)): + ''' + same output as PIL.Image.transform + ''' + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, -offset], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, + flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def translate_y_func(img, offset, fill=(0, 0, 0)): + ''' + same output as PIL.Image.transform + ''' + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [0, 1, -offset]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, + flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def posterize_func(img, bits): + ''' + same output as PIL.ImageOps.posterize + ''' + out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) + return out + + +def shear_y_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [factor, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, + flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def cutout_func(img, pad_size, replace=(0, 0, 0)): + replace = np.array(replace, dtype=np.uint8) + H, W = img.shape[0], img.shape[1] + rh, rw = np.random.random(2) + pad_size = pad_size // 2 + ch, cw = int(rh * H), int(rw * W) + x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) + y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) + out = img.copy() + out[x1:x2, y1:y2, :] = replace + return out + + +# level to args +def enhance_level_to_args(MAX_LEVEL): + + def level_to_args(level): + return ((level / MAX_LEVEL) * 1.8 + 0.1, ) + + return level_to_args + + +def shear_level_to_args(MAX_LEVEL, replace_value): + + def level_to_args(level): + level = (level / MAX_LEVEL) * 0.3 + if np.random.random() > 0.5: + level = -level + return level, replace_value + + return level_to_args + + +def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): + + def level_to_args(level): + level = (level / MAX_LEVEL) * float(translate_const) + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): + + def level_to_args(level): + level = int((level / MAX_LEVEL) * cutout_const) + return (level, replace_value) + + return level_to_args + + +def solarize_level_to_args(MAX_LEVEL): + + def level_to_args(level): + level = int((level / MAX_LEVEL) * 256) + return (level, ) + + return level_to_args + + +def none_level_to_args(level): + return () + + +def posterize_level_to_args(MAX_LEVEL): + + def level_to_args(level): + level = int((level / MAX_LEVEL) * 4) + return (level, ) + + return level_to_args + + +def rotate_level_to_args(MAX_LEVEL, replace_value): + + def level_to_args(level): + level = (level / MAX_LEVEL) * 30 + if np.random.random() < 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +func_dict = { + 'Identity': identity_func, + 'AutoContrast': autocontrast_func, + 'Equalize': equalize_func, + 'Rotate': rotate_func, + 'Solarize': solarize_func, + 'Color': color_func, + 'Contrast': contrast_func, + 'Brightness': brightness_func, + 'Sharpness': sharpness_func, + 'ShearX': shear_x_func, + 'TranslateX': translate_x_func, + 'TranslateY': translate_y_func, + 'Posterize': posterize_func, + 'ShearY': shear_y_func, +} + +translate_const = 10 +MAX_LEVEL = 10 
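+# A level in [0, MAX_LEVEL] is mapped to concrete op arguments by the
+# *_level_to_args helpers above; replace_value is the fill colour used by the
+# geometric ops (Rotate, ShearX/Y, TranslateX/Y).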
+replace_value = (128, 128, 128) +arg_dict = { + 'Identity': + none_level_to_args, + 'AutoContrast': + none_level_to_args, + 'Equalize': + none_level_to_args, + 'Rotate': + rotate_level_to_args(MAX_LEVEL, replace_value), + 'Solarize': + solarize_level_to_args(MAX_LEVEL), + 'Color': + enhance_level_to_args(MAX_LEVEL), + 'Contrast': + enhance_level_to_args(MAX_LEVEL), + 'Brightness': + enhance_level_to_args(MAX_LEVEL), + 'Sharpness': + enhance_level_to_args(MAX_LEVEL), + 'ShearX': + shear_level_to_args(MAX_LEVEL, replace_value), + 'TranslateX': + translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + 'TranslateY': + translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + 'Posterize': + posterize_level_to_args(MAX_LEVEL), + 'ShearY': + shear_level_to_args(MAX_LEVEL, replace_value), +} + + +class RandomAugment(object): + + def __init__(self, N=2, M=10, isPIL=False, augs=[]): + self.N = N + self.M = M + self.isPIL = isPIL + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N) + return [(op, 0.5, self.M) for op in sampled_ops] + + def __call__(self, img): + if self.isPIL: + img = np.array(img) + ops = self.get_random_ops() + for name, prob, level in ops: + if np.random.random() > prob: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return img diff --git a/modelscope/preprocessors/ofa/visual_entailment.py b/modelscope/preprocessors/ofa/visual_entailment.py new file mode 100644 index 00000000..72e88d75 --- /dev/null +++ b/modelscope/preprocessors/ofa/visual_entailment.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.preprocessors.image import load_image +from .base import OfaBasePreprocessor + + +class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): + + def __init__(self, cfg, model_dir): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path + """ + super(OfaVisualEntailmentPreprocessor, self).__init__(cfg, model_dir) + # Initialize transform + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize((self.patch_image_size, self.patch_image_size), + interpolation=Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = data['image'] if isinstance( + data['image'], Image.Image) else load_image(data['image']) + patch_image = self.patch_resize_transform(image) + if 'text2' not in data: + hypothesis = self.pre_caption(data['text'], self.max_src_length) + prompt = self.cfg.model.get('prompt', + ' does the image describe " {} "?') + text = prompt.format(hypothesis) + else: + assert 'text' in data, f'text must be in the input {data.keys()}' + caption = self.pre_caption(data['text2'], self.max_src_length) + hypothesis = self.pre_caption(data['text'], self.max_src_length) + prompt = self.cfg.model.get( + 'prompt', ' can image and text1 " {} " imply text2 " {} "?') + text = prompt.format(caption, hypothesis) + inputs = self.get_inputs(text) + if self.prompt_type == 'none': + decoder_prompt = self.bos_item + elif self.prompt_type == 'src': + decoder_prompt = inputs + elif self.prompt_type == 'prev_output': + decoder_prompt 
= inputs[:-1] + else: + raise NotImplementedError + sample = { + 'source': inputs, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]), + 'decoder_prompt': decoder_prompt, + } + return sample diff --git a/modelscope/preprocessors/ofa/visual_grounding.py b/modelscope/preprocessors/ofa/visual_grounding.py new file mode 100644 index 00000000..eebc4cf2 --- /dev/null +++ b/modelscope/preprocessors/ofa/visual_grounding.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.preprocessors.image import load_image +from .base import OfaBasePreprocessor + + +class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): + + def __init__(self, cfg, model_dir): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path + """ + super(OfaVisualGroundingPreprocessor, self).__init__(cfg, model_dir) + # Initialize transform + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize((self.patch_image_size, self.patch_image_size), + interpolation=Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = data['image'] if isinstance( + data['image'], Image.Image) else load_image(data['image']) + w, h = image.size + patch_image = self.patch_resize_transform(image) + w_resize_ratio = torch.tensor(self.patch_image_size / w) + h_resize_ratio = torch.tensor(self.patch_image_size / h) + src_caption = self.pre_caption(data['text'], self.max_src_length) + prompt = self.cfg.model.get( + 'prompt', ' which region does the text " {} " describe?') + text = prompt.format(src_caption) + src_item = self.get_inputs(text) + sample = { + 'source': src_item, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]), + 'w_resize_ratio': w_resize_ratio, + 'h_resize_ratio': h_resize_ratio, + } + return sample diff --git a/modelscope/preprocessors/ofa/visual_question_answering.py b/modelscope/preprocessors/ofa/visual_question_answering.py new file mode 100644 index 00000000..b11af9f6 --- /dev/null +++ b/modelscope/preprocessors/ofa/visual_question_answering.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
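+# Visual question answering preprocessor: the image is resized into patches and
+# the raw question text is used directly as the encoder source.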
+from typing import Any, Dict + +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.preprocessors.image import load_image +from .base import OfaBasePreprocessor + + +class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): + + def __init__(self, cfg, model_dir): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path + """ + super(OfaVisualQuestionAnsweringPreprocessor, + self).__init__(cfg, model_dir) + # Initialize transform + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize((self.patch_image_size, self.patch_image_size), + interpolation=Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = data['image'] if isinstance( + data['image'], Image.Image) else load_image(data['image']) + patch_image = self.patch_resize_transform(image) + text = ' {}'.format(data['text']) + inputs = self.get_inputs(text) + if self.prompt_type == 'none': + decoder_prompt = self.bos_item + elif self.prompt_type == 'src': + decoder_prompt = inputs + elif self.prompt_type == 'prev_output': + decoder_prompt = inputs[:-1] + else: + raise NotImplementedError + sample = { + 'source': inputs, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]), + 'decoder_prompt': decoder_prompt, + } + return sample diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 4bb6ba5d..c4aace7e 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -85,6 +85,7 @@ class MultiModalTasks(object): multi_modal_embedding = 'multi-modal-embedding' generative_multi_modal_embedding = 'generative-multi-modal-embedding' visual_question_answering = 'visual-question-answering' + visual_entailment = 'visual-entailment' video_multi_modal_embedding = 'video-multi-modal-embedding' diff --git a/modelscope/utils/trie.py b/modelscope/utils/trie.py new file mode 100644 index 00000000..77f7e971 --- /dev/null +++ b/modelscope/utils/trie.py @@ -0,0 +1,29 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from collections import defaultdict + + +class TreeNode: + + def __init__(self): + self.child = defaultdict(TreeNode) + + +class Trie: + + def __init__(self, eos): + self.root = TreeNode() + self.eos = eos + + def insert(self, word): + cur = self.root + for c in word: + cur = cur.child[c] + + def get_next_layer(self, word): + cur = self.root + for c in word: + cur = cur.child.get(c) + if cur is None: + return [self.eos] + return list(cur.child.keys()) diff --git a/tests/pipelines/test_image_captioning.py b/tests/pipelines/test_image_captioning.py deleted file mode 100644 index fc029146..00000000 --- a/tests/pipelines/test_image_captioning.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
- -import unittest - -from modelscope.outputs import OutputKeys -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.test_utils import test_level - - -class ImageCaptionTest(unittest.TestCase): - - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') - def test_run(self): - img_captioning = pipeline( - Tasks.image_captioning, - model='damo/ofa_image-caption_coco_distilled_en') - result = img_captioning('data/test/images/image_captioning.png') - print(result[OutputKeys.CAPTION]) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py new file mode 100644 index 00000000..2c494e40 --- /dev/null +++ b/tests/pipelines/test_ofa_tasks.py @@ -0,0 +1,179 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class OfaTasksTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_captioning_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_image-caption_coco_distilled_en') + img_captioning = pipeline( + task=Tasks.image_captioning, + model=model, + ) + result = img_captioning( + {'image': 'data/test/images/image_captioning.png'}) + print(result[OutputKeys.CAPTION]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_captioning_with_name(self): + img_captioning = pipeline( + Tasks.image_captioning, + model='damo/ofa_image-caption_coco_distilled_en') + result = img_captioning( + {'image': 'data/test/images/image_captioning.png'}) + print(result[OutputKeys.CAPTION]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_classification_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_image-classification_imagenet_large_en') + ofa_pipe = pipeline(Tasks.image_classification, model=model) + image = 'data/test/images/image_classification.png' + input = {'image': image} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_classification_with_name(self): + ofa_pipe = pipeline( + Tasks.image_classification, + model='damo/ofa_image-classification_imagenet_large_en') + image = 'data/test/images/image_classification.png' + input = {'image': image} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_summarization_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_summarization_gigaword_large_en') + ofa_pipe = pipeline(Tasks.summarization, model=model) + text = 'five-time world champion michelle kwan withdrew' + \ + 'from the #### us figure skating championships on wednesday ,' + \ + ' but will petition us skating officials for the chance to ' + \ + 'compete at the #### turin olympics .' 
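+        # '####' masks digit spans, mirroring the Gigaword preprocessing the
+        # summarization model was trained on.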
+ input = {'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_summarization_with_name(self): + ofa_pipe = pipeline( + Tasks.summarization, + model='damo/ofa_summarization_gigaword_large_en') + text = 'five-time world champion michelle kwan withdrew' + \ + 'from the #### us figure skating championships on wednesday ,' + \ + ' but will petition us skating officials for the chance to ' +\ + 'compete at the #### turin olympics .' + input = {'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_text_classification_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_text-classification_mnli_large_en') + ofa_pipe = pipeline(Tasks.text_classification, model=model) + text = 'One of our number will carry out your instructions minutely.' + text2 = 'A member of my team will execute your orders with immense precision.' + input = {'text': text, 'text2': text2} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_text_classification_with_name(self): + ofa_pipe = pipeline( + Tasks.text_classification, + model='damo/ofa_text-classification_mnli_large_en') + text = 'One of our number will carry out your instructions minutely.' + text2 = 'A member of my team will execute your orders with immense precision.' + input = {'text': text, 'text2': text2} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_entailment_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_visual-entailment_snli-ve_large_en') + ofa_pipe = pipeline(Tasks.visual_entailment, model=model) + image = 'data/test/images/dogs.jpg' + text = 'there are two birds.' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_entailment_with_name(self): + ofa_pipe = pipeline( + Tasks.visual_entailment, + model='damo/ofa_visual-entailment_snli-ve_large_en') + image = 'data/test/images/dogs.jpg' + text = 'there are two birds.' 
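+        # A caption about birds paired with 'dogs.jpg' probes the
+        # non-entailment side of the model.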
+ input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_grounding_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_visual-grounding_refcoco_large_en') + ofa_pipe = pipeline(Tasks.visual_grounding, model=model) + image = 'data/test/images/visual_grounding.png' + text = 'a blue turtle-like pokemon with round head' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_grounding_with_name(self): + ofa_pipe = pipeline( + Tasks.visual_grounding, + model='damo/ofa_visual-grounding_refcoco_large_en') + image = 'data/test/images/visual_grounding.png' + text = 'a blue turtle-like pokemon with round head' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_question_answering_with_model(self): + from modelscope.preprocessors.multi_modal import OfaPreprocessor + model = Model.from_pretrained( + 'damo/ofa_visual-question-answering_pretrain_large_en') + preprocessor = OfaPreprocessor(model_dir=model.model_dir) + ofa_pipe = pipeline( + Tasks.visual_question_answering, + model=model, + preprocessor=preprocessor) + image = 'data/test/images/visual_question_answering.png' + text = 'what is grown on the plant?' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_question_answering_with_name(self): + from modelscope.preprocessors.multi_modal import OfaPreprocessor + model = 'damo/ofa_visual-question-answering_pretrain_large_en' + preprocessor = OfaPreprocessor(model_dir=model) + ofa_pipe = pipeline( + Tasks.visual_question_answering, + model=model, + preprocessor=preprocessor) + image = 'data/test/images/visual_question_answering.png' + text = 'what is grown on the plant?' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + +if __name__ == '__main__': + unittest.main()
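For reference, a minimal usage sketch of the Trie helper added above in modelscope/utils/trie.py; the token ids below are invented purely for illustration:

    from modelscope.utils.trie import Trie

    # Treat 2 as a stand-in EOS id; real ids would come from the OFA tokenizer.
    trie = Trie(eos=2)
    trie.insert([50, 61, 2])  # a tokenized candidate label followed by EOS
    trie.insert([50, 73, 2])

    # Tokens allowed to follow a generated prefix:
    print(trie.get_next_layer([50]))  # -> [61, 73]
    # An unseen prefix falls back to EOS only:
    print(trie.get_next_layer([99]))  # -> [2]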