Mirror of https://github.com/modelscope/modelscope.git, synced 2026-02-24 20:19:51 +01:00
[to #42322933] OFA text-to-image: hook up CLIP reranking post-processing & fix a bug in preprocessing
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9880918
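In outline: forward() now samples several candidate images per prompt, decodes them all through the VQGAN, scores each candidate against the prompt with CLIP, and returns the best-scoring image; the preprocessing fix re-joins the truncated word list into a string before it is spliced into the prompt template. A minimal sketch of the reranking step, assuming an already-loaded CLIP model and its preprocessing transform (all names below are illustrative, not part of the diff):

import torch

def rerank_by_clip(pil_imgs, text_tokens, clip_model, preprocess, device):
    # Encode candidates and prompt, L2-normalize, sort by cosine similarity.
    image_input = torch.stack([preprocess(im) for im in pil_imgs]).to(device)
    with torch.no_grad():
        img_feat = clip_model.encode_image(image_input)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
        txt_feat = clip_model.encode_text(text_tokens.to(device))
        txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
        sim = (img_feat @ txt_feat.T).view(-1)
    order = sim.argsort(descending=True)
    return [pil_imgs[i] for i in order]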
@@ -6,17 +6,30 @@ import numpy as np
 import torch
 import torch.cuda
+from PIL import Image
+from pkg_resources import packaging
 from taming.models.vqgan import GumbelVQ, VQModel
+from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
+                                    ToTensor)

 from modelscope.metainfo import Models
 from modelscope.models.base import Model
 from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.mmr.models.module_clip import CLIP
+from modelscope.models.multi_modal.mmr.models.tokenization_clip import \
+    SimpleTokenizer as ClipTokenizer
 from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer
 from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg
 from modelscope.models.multi_modal.ofa.generate.search import Sampling
 from modelscope.models.multi_modal.ofa.generate.utils import move_to_device
 from modelscope.utils.constant import Tasks

+try:
+    from torchvision.transforms import InterpolationMode
+
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
 __all__ = ['OfaForTextToImageSynthesis']
@@ -43,6 +56,74 @@ def load_vqgan(config, ckpt_path=None, is_gumbel=False):
     return model.eval()


+def build_clip_model(model_path):
+    state_dict = torch.load(model_path, map_location='cpu').state_dict()
+    vit = 'visual.proj' in state_dict
+    if vit:
+        vision_width = state_dict['visual.conv1.weight'].shape[0]
+        vision_layers = len([
+            k for k in state_dict.keys()
+            if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
+        ])
+        vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
+        grid_size = round(
+            (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
+        image_resolution = vision_patch_size * grid_size
+    else:
+        counts: list = [
+            len(
+                set(
+                    k.split('.')[2] for k in state_dict
+                    if k.startswith(f'visual.layer{b}')))
+            for b in [1, 2, 3, 4]
+        ]
+        vision_layers = tuple(counts)
+        vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
+        output_width = round(
+            (state_dict['visual.attnpool.positional_embedding'].shape[0]
+             - 1)**0.5)
+        vision_patch_size = None
+        assert output_width**2 + 1 == state_dict[
+            'visual.attnpool.positional_embedding'].shape[0]
+        image_resolution = output_width * 32
+
+    embed_dim = state_dict['text_projection'].shape[1]
+    context_length = state_dict['positional_embedding'].shape[0]
+    vocab_size = state_dict['token_embedding.weight'].shape[0]
+    transformer_width = state_dict['ln_final.weight'].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(
+        set(
+            k.split('.')[2] for k in state_dict
+            if k.startswith('transformer.resblocks')))
+
+    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
+                 vision_patch_size, context_length, vocab_size,
+                 transformer_width, transformer_heads, transformer_layers)
+
+    for key in ['input_resolution', 'context_length', 'vocab_size']:
+        if key in state_dict:
+            del state_dict[key]
+
+    model.load_state_dict(state_dict)
+    return model.eval()
+
+
+def _convert_image_to_rgb(image):
+    return image.convert('RGB')
+
+
+def build_clip_transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073),
+                  (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+
 @MODELS.register_module(Tasks.text_to_image_synthesis, module_name=Models.ofa)
 class OfaForTextToImageSynthesis(Model):
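build_clip_model mirrors OpenAI CLIP's build_model recipe: every architecture hyperparameter (width, depth, patch size, input resolution) is recovered from the checkpoint's tensor shapes and key names, so no separate config file is needed. For instance, transformer depth is just the number of distinct residual-block indices among the state-dict keys; a tiny self-contained illustration with made-up keys:

keys = [
    'transformer.resblocks.0.attn.in_proj_weight',
    'transformer.resblocks.0.mlp.c_fc.weight',
    'transformer.resblocks.1.attn.in_proj_weight',
]
transformer_layers = len(
    set(k.split('.')[2] for k in keys if k.startswith('transformer.resblocks')))
print(transformer_layers)  # -> 2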
@@ -65,11 +146,23 @@ class OfaForTextToImageSynthesis(Model):
             vqgan_config,
             ckpt_path=os.path.join(model_dir, 'vqgan_model.ckpt'),
             is_gumbel=True).to(self._device)

+        # Initialize OpenAI clip
+        self.clip_tokenizer = ClipTokenizer(model_dir)
+        self.clip_model = build_clip_model(
+            os.path.join(model_dir, 'ViT-B-16.pt'))
+        self.clip_preprocess = build_clip_transform(
+            self.clip_model.visual.input_resolution)
+        self.clip_model.to(self._device)
+        self.clip_model.eval()
+
         # Initialize generator
         sampling = Sampling(self.tokenizer, sampling_topp=0.9)
         sg_args = {
             'tokenizer': self.tokenizer,
-            'beam_size': 1,
+            'beam_size': 2,
             'max_len_b': 1024,
             'min_len': 1024,
             'search_strategy': sampling,
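Note the search strategy here is top-p Sampling rather than true beam search, so beam_size effectively sets how many candidates are drawn per prompt; raising it from 1 to 2 is what gives the CLIP reranker below more than one image to choose from.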
@@ -78,13 +171,68 @@ class OfaForTextToImageSynthesis(Model):
         }
         self.generator = sg.SequenceGenerator(**sg_args)

+    def clip_tokenize(self, texts, context_length=77, truncate=False):
+        if isinstance(texts, str):
+            texts = [texts]
+
+        sot_token = self.clip_tokenizer.encoder['<|startoftext|>']
+        eot_token = self.clip_tokenizer.encoder['<|endoftext|>']
+        all_tokens = [[sot_token] + self.clip_tokenizer.encode(text)
+                      + [eot_token] for text in texts]
+        if packaging.version.parse(
+                torch.__version__) < packaging.version.parse('1.8.0'):
+            result = torch.zeros(
+                len(all_tokens), context_length, dtype=torch.long)
+        else:
+            result = torch.zeros(
+                len(all_tokens), context_length, dtype=torch.int)
+
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > context_length:
+                if truncate:
+                    tokens = tokens[:context_length]
+                    tokens[-1] = eot_token
+                else:
+                    raise RuntimeError(
+                        f'Input {texts[i]} is too long for context length {context_length}'
+                    )
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
+
     def forward(self, input: Dict[str, Any]):
+        text = input['samples'][0]['text']
         input = move_to_device(input, self._device)
+        clip_text_input = self.clip_tokenize([text]).to(self._device)
+
         gen_output = self.generator.generate([self.model], input)
-        gen_tokens = gen_output[0][0]['tokens'][:-1]
-        codes = gen_tokens.view(1, 32, 32) - 50265
+        gen_tokens = torch.stack(
+            [item['tokens'][:-1] for item in gen_output[0]], dim=0)
+        codes = gen_tokens.view(-1, 32, 32) - 50265
+
         quant_b = self.vqgan_model.quantize.get_codebook_entry(
             codes.view(-1),
             list(codes.size()) + [self.vqgan_model.quantize.embedding_dim])
-        dec = self.vqgan_model.decode(quant_b)[0]
-        return custom_to_pil(dec)
+        imgs = self.vqgan_model.decode(quant_b)
+
+        sample_num = imgs.size()[0]
+        pil_imgs = [custom_to_pil(imgs[i]) for i in range(sample_num)]
+
+        clip_image_input = torch.stack(
+            [self.clip_preprocess(img) for img in pil_imgs],
+            dim=0).to(self._device)
+
+        with torch.no_grad():
+            hyp_image_features = self.clip_model.encode_image(clip_image_input)
+            hyp_image_features /= hyp_image_features.norm(dim=-1, keepdim=True)
+            text_features = self.clip_model.encode_text(clip_text_input)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+            ti_similarity = hyp_image_features @ text_features.T
+
+        sorted_score, ti_indices = torch.sort(
+            ti_similarity.view(-1), descending=True)
+
+        pil_imgs_orderby_ti = [pil_imgs[index] for index in ti_indices]
+        return pil_imgs_orderby_ti[0]
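In the new forward(), each candidate's generated token sequence (min_len = max_len_b = 1024 above) is reshaped to a 32 x 32 code grid, and the constant 50265 shifts OFA vocabulary ids back down to VQGAN codebook indices before decoding; CLIP then ranks the decoded images against the prompt. A quick shape check under those assumptions (the codebook size 8192 here is illustrative):

import torch
gen_tokens = torch.randint(50265, 50265 + 8192, (2, 1024))  # 2 candidates
codes = gen_tokens.view(-1, 32, 32) - 50265                 # codebook indices
print(codes.shape)  # torch.Size([2, 32, 32])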
@@ -19,7 +19,8 @@ class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):
         self.max_src_length = 64

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        source = data['text'].lower().strip().split()[:self.max_src_length]
+        source = ' '.join(
+            data['text'].lower().strip().split()[:self.max_src_length])
         source = 'what is the complete image? caption: {}'.format(source)
         inputs = self.get_inputs(source)
         sample = {
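The one-line preprocessing fix matters because str.split() returns a list: without the ' '.join, the prompt template interpolated the list's Python repr rather than the caption text. A quick demonstration:

text = 'Two dogs playing in the snow '
broken = text.lower().strip().split()[:64]
fixed = ' '.join(text.lower().strip().split()[:64])
print('what is the complete image? caption: {}'.format(broken))
# what is the complete image? caption: ['two', 'dogs', 'playing', 'in', 'the', 'snow']
print('what is the complete image? caption: {}'.format(fixed))
# what is the complete image? caption: two dogs playing in the snow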