[to #42322933] Add ofa-text-to-image-synthesis to maas lib
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9590940
@@ -203,6 +203,7 @@ class Preprocessors(object):
     # multi-modal
     ofa_image_caption = 'ofa-image-caption'
+    ofa_text_to_image_synthesis = 'ofa-text-to-image-synthesis'
     mplug_visual_question_answering = 'mplug-visual-question-answering'
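This registry key is consumed further down in this diff, where OfaPreprocessor gains a second @PREPROCESSORS.register_module(...) decorator keyed on Preprocessors.ofa_text_to_image_synthesis.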
@@ -20,7 +20,9 @@ else:
         'mmr': ['VideoCLIPForMultiModalEmbedding'],
         'mplug_for_visual_question_answering':
         ['MPlugForVisualQuestionAnswering'],
-        'ofa_for_all_tasks': ['OfaForAllTasks']
+        'ofa_for_all_tasks': ['OfaForAllTasks'],
+        'ofa_for_text_to_image_synthesis_model':
+        ['OfaForTextToImageSynthesis']
     }

     import sys
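For context, this dict feeds modelscope's lazy-import machinery: a submodule is only imported when one of its classes is first requested. A minimal stand-in (illustrative only, not modelscope's actual LazyImportModule) looks like this:

import importlib

# Submodule name -> classes it exports, mirroring the entries above.
_import_structure = {
    'ofa_for_all_tasks': ['OfaForAllTasks'],
    'ofa_for_text_to_image_synthesis_model': ['OfaForTextToImageSynthesis'],
}


def resolve(class_name, package='modelscope.models.multi_modal'):
    """Import the submodule that declares `class_name` only when asked for it."""
    for module_name, classes in _import_structure.items():
        if class_name in classes:
            module = importlib.import_module('.' + module_name, package)
            return getattr(module, class_name)
    raise ImportError(class_name)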
@@ -0,0 +1,90 @@
+import os
+from typing import Any, Dict
+
+import json
+import numpy as np
+import torch
+import torch.cuda
+from PIL import Image
+from taming.models.vqgan import GumbelVQ, VQModel
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Model
+from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer
+from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg
+from modelscope.models.multi_modal.ofa.generate.search import Sampling
+from modelscope.models.multi_modal.ofa.generate.utils import move_to_device
+from modelscope.utils.constant import Tasks
+
+__all__ = ['OfaForTextToImageSynthesis']
+
+
+def custom_to_pil(x):
+    x = x.detach().cpu()
+    x = torch.clamp(x, -1., 1.)
+    x = (x + 1.) / 2.
+    x = x.permute(1, 2, 0).numpy()
+    x = (255 * x).astype(np.uint8)
+    x = Image.fromarray(x)
+    if not x.mode == 'RGB':
+        x = x.convert('RGB')
+    return x
+
+
+def load_vqgan(config, ckpt_path=None, is_gumbel=False):
+    if is_gumbel:
+        model = GumbelVQ(**config['model']['params'])
+    else:
+        model = VQModel(**config['model']['params'])
+    if ckpt_path is not None:
+        sd = torch.load(ckpt_path, map_location='cpu')['state_dict']
+        missing, unexpected = model.load_state_dict(sd, strict=False)
+    return model.eval()
+
+
+@MODELS.register_module(Tasks.text_to_image_synthesis, module_name=Models.ofa)
+class OfaForTextToImageSynthesis(Model):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        super().__init__(model_dir=model_dir, *args, **kwargs)
+        # Initialize ofa
+        model = OFAModel.from_pretrained(model_dir)
+        self.model = model.module if hasattr(model, 'module') else model
+        self.tokenizer = OFATokenizer.from_pretrained(model_dir)
+        self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
+        self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
+        self._device = torch.device('cuda') if torch.cuda.is_available() \
+            else torch.device('cpu')
+        self.model.to(self._device)
+
+        # Initialize vqgan
+        vqgan_config = json.load(
+            open(os.path.join(model_dir, 'vqgan_config.json')))
+        self.vqgan_model = load_vqgan(
+            vqgan_config,
+            ckpt_path=os.path.join(model_dir, 'vqgan_model.ckpt'),
+            is_gumbel=True).to(self._device)
+        # Initialize generator
+        sampling = Sampling(self.tokenizer, sampling_topp=0.9)
+        sg_args = {
+            'tokenizer': self.tokenizer,
+            'beam_size': 1,
+            'max_len_b': 1024,
+            'min_len': 1024,
+            'search_strategy': sampling,
+            'gen_code': True,
+            'constraint_range': '50265,58457'
+        }
+        self.generator = sg.SequenceGenerator(**sg_args)
+
+    def forward(self, input: Dict[str, Any]):
+        input = move_to_device(input, self._device)
+        gen_output = self.generator.generate([self.model], input)
+        gen_tokens = gen_output[0][0]['tokens'][:-1]
+        codes = gen_tokens.view(1, 32, 32) - 50265
+        quant_b = self.vqgan_model.quantize.get_codebook_entry(
+            codes.view(-1),
+            list(codes.size()) + [self.vqgan_model.quantize.embedding_dim])
+        dec = self.vqgan_model.decode(quant_b)[0]
+        return custom_to_pil(dec)
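A quick sanity check of the token bookkeeping in forward() above (a standalone sketch, not part of the commit): the constraint_range of '50265,58457' covers exactly the 8192 <code_{i}> tokens registered on the tokenizer, and the fixed generation length of 1024 (min_len == max_len_b) is precisely the 32 x 32 latent grid the VQGAN decoder consumes.

# Standalone sketch: verify the arithmetic that forward() relies on.
code_vocab_start, code_vocab_end = 50265, 58457  # from 'constraint_range'

# The constrained id range has one slot per <code_{i}> token / VQGAN codebook entry.
assert code_vocab_end - code_vocab_start == 8192

# 1024 generated code tokens reshape to the 32 x 32 latent grid decoded by VQGAN.
assert 32 * 32 == 1024

# Subtracting the range start maps generated token ids back to codebook indices,
# which is exactly what `gen_tokens.view(1, 32, 32) - 50265` does above.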
@@ -1,11 +1,13 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import torch

 from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForTextToImageSynthesis
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import OfaPreprocessor, Preprocessor
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
@@ -17,7 +19,10 @@ logger = get_logger()
     module_name=Pipelines.text_to_image_synthesis)
 class TextToImageSynthesisPipeline(Pipeline):

-    def __init__(self, model: str, **kwargs):
+    def __init__(self,
+                 model: str,
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
         """
         use `model` and `preprocessor` to create a text-to-image synthesis pipeline for prediction
         Args:
@@ -31,13 +36,20 @@ class TextToImageSynthesisPipeline(Pipeline):
         else:
             raise NotImplementedError(
                 f'expecting a Model instance or str, but get {type(model)}.')
-        super().__init__(model=pipe_model, **kwargs)
+        if preprocessor is None and isinstance(pipe_model,
+                                               OfaForTextToImageSynthesis):
+            preprocessor = OfaPreprocessor(pipe_model.model_dir)
+        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

-    def preprocess(self, input: Input) -> Dict[str, Any]:
-        return input
+    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
+        if self.preprocessor is not None:
+            return self.preprocessor(input, **preprocess_params)
+        else:
+            return input

     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        if isinstance(self.model, OfaForTextToImageSynthesis):
+            return self.model(input)
         return self.model.generate(input)

     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
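A minimal usage sketch for the new constructor signature (assuming the model id used in the tests below is downloadable in your environment):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Default path: preprocessor=None, so the pipeline builds an OfaPreprocessor
# from the model directory itself (the isinstance branch above).
pipe = pipeline(
    Tasks.text_to_image_synthesis,
    model='damo/ofa_text-to-image-synthesis_coco_large_en')
result = pipe({'text': 'a bear in the water.'})

# Explicit path: hand in a ready-made preprocessor via the new argument,
# e.g. preprocessor=OfaPreprocessor(model_dir) for a local checkout.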
@@ -23,6 +23,8 @@ __all__ = [

 @PREPROCESSORS.register_module(
     Fields.multi_modal, module_name=Preprocessors.ofa_image_caption)
+@PREPROCESSORS.register_module(
+    Fields.multi_modal, module_name=Preprocessors.ofa_text_to_image_synthesis)
 class OfaPreprocessor(Preprocessor):

     def __init__(self, model_dir: str, *args, **kwargs):
@@ -40,7 +42,8 @@ class OfaPreprocessor(Preprocessor):
             Tasks.visual_entailment: OfaVisualEntailmentPreprocessor,
             Tasks.image_classification: OfaImageClassificationPreprocessor,
             Tasks.text_classification: OfaTextClassificationPreprocessor,
-            Tasks.summarization: OfaSummarizationPreprocessor
+            Tasks.summarization: OfaSummarizationPreprocessor,
+            Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor
         }
         input_key_mapping = {
             Tasks.image_captioning: ['image'],
@@ -50,6 +53,7 @@ class OfaPreprocessor(Preprocessor):
             Tasks.visual_grounding: ['image', 'text'],
             Tasks.visual_question_answering: ['image', 'text'],
             Tasks.visual_entailment: ['image', 'text', 'text2'],
+            Tasks.text_to_image_synthesis: ['text']
         }
         model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
             model_dir)
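The two dicts above implement a simple table-driven dispatch: the task id selects both the concrete preprocessor class and the ordered input keys it consumes. Schematically (variable names here are illustrative, not the surrounding method's):

# Illustrative only: how the two tables cooperate for the new task.
cls = preprocess_mapping[Tasks.text_to_image_synthesis]   # OfaTextToImageSynthesisPreprocessor
keys = input_key_mapping[Tasks.text_to_image_synthesis]   # ['text']

# Positional inputs are zipped onto the task's key list before preprocessing.
data = dict(zip(keys, ('a bear in the water.',)))          # {'text': 'a bear in the water.'}
sample = cls(cfg, model_dir)(data)                         # cfg/model_dir as in OfaPreprocessor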
@@ -3,6 +3,7 @@ from .image_captioning import OfaImageCaptioningPreprocessor
 from .image_classification import OfaImageClassificationPreprocessor
 from .summarization import OfaSummarizationPreprocessor
 from .text_classification import OfaTextClassificationPreprocessor
+from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor
 from .visual_entailment import OfaVisualEntailmentPreprocessor
 from .visual_grounding import OfaVisualGroundingPreprocessor
 from .visual_question_answering import OfaVisualQuestionAnsweringPreprocessor
modelscope/preprocessors/ofa/text_to_image_synthesis.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict
+
+import torch
+
+from .base import OfaBasePreprocessor
+
+
+class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):
+
+    def __init__(self, cfg, model_dir):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+        """
+        super(OfaTextToImageSynthesisPreprocessor,
+              self).__init__(cfg, model_dir)
+        self.max_src_length = 64
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        source = ' '.join(
+            data['text'].lower().strip().split()[:self.max_src_length])
+        source = 'what is the complete image? caption: {}'.format(source)
+        inputs = self.get_inputs(source)
+        sample = {
+            'source': inputs,
+            'patch_images': None,
+            'patch_masks': torch.tensor([False]),
+            'code_masks': torch.tensor([False])
+        }
+        return sample
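Tracing the test prompt through __call__ above (standalone sketch):

text = 'a bear in the water.'
tokens = text.lower().strip().split()[:64]   # ['a', 'bear', 'in', 'the', 'water.']
source = 'what is the complete image? caption: {}'.format(' '.join(tokens))
assert source == 'what is the complete image? caption: a bear in the water.'
# get_inputs() tokenizes this prompt; patch_images stays None and both masks
# are False because text-to-image feeds no image patches to the encoder.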
@@ -6,6 +6,7 @@ pycocotools>=2.0.4
 # rouge-score was just recently updated from 0.0.4 to 0.0.7
 # which introduced compatibility issues that are being investigated
 rouge_score<=0.0.4
+taming-transformers-rom1504
 timm
 tokenizers
 torchvision
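Note: taming-transformers-rom1504 is the package that provides the taming.models.vqgan module (GumbelVQ, VQModel) imported by the new model file above.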
@@ -244,6 +244,25 @@ class OfaTasksTest(unittest.TestCase):
         result = ofa_pipe(input)
         print(result)

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_text_to_image_synthesis_with_name(self):
+        model = 'damo/ofa_text-to-image-synthesis_coco_large_en'
+        ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model)
+        example = {'text': 'a bear in the water.'}
+        result = ofa_pipe(example)
+        result[OutputKeys.OUTPUT_IMG].save('result.png')
+        print(f'Output written to {osp.abspath("result.png")}')
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_text_to_image_synthesis_with_model(self):
+        model = Model.from_pretrained(
+            'damo/ofa_text-to-image-synthesis_coco_large_en')
+        ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model)
+        example = {'text': 'a bear in the water.'}
+        result = ofa_pipe(example)
+        result[OutputKeys.OUTPUT_IMG].save('result.png')
+        print(f'Output written to {osp.abspath("result.png")}')
+

 if __name__ == '__main__':
     unittest.main()