From 690473ce85097dda677f3d0a4c2ce2c7e6d4383b Mon Sep 17 00:00:00 2001 From: "huizheng.hz" Date: Fri, 20 Oct 2023 16:10:54 +0800 Subject: [PATCH 1/2] add freeU model Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14307648 * support sd21, sdxl --- modelscope/metainfo.py | 1 + .../models/multi_modal/freeu/__init__.py | 22 ++ .../multi_modal/freeu/free_lunch_utils.py | 331 ++++++++++++++++++ modelscope/pipelines/multi_modal/__init__.py | 4 +- .../text_to_image_freeu_pipeline.py | 138 ++++++++ tests/pipelines/test_text_to_image_freeu.py | 57 +++ 6 files changed, 552 insertions(+), 1 deletion(-) create mode 100644 modelscope/models/multi_modal/freeu/__init__.py create mode 100644 modelscope/models/multi_modal/freeu/free_lunch_utils.py create mode 100644 modelscope/pipelines/multi_modal/text_to_image_freeu_pipeline.py create mode 100644 tests/pipelines/test_text_to_image_freeu.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index ea56efb5..377ade9b 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -291,6 +291,7 @@ class Pipelines(object): image_denoise = 'nafnet-image-denoise' image_deblur = 'nafnet-image-deblur' image_editing = 'masactrl-image-editing' + freeu_stable_diffusion_text2image = 'freeu-stable-diffusion-text2image' person_image_cartoon = 'unet-person-image-cartoon' ocr_detection = 'resnet18-ocr-detection' table_recognition = 'dla34-table-recognition' diff --git a/modelscope/models/multi_modal/freeu/__init__.py b/modelscope/models/multi_modal/freeu/__init__.py new file mode 100644 index 00000000..3cd55cf3 --- /dev/null +++ b/modelscope/models/multi_modal/freeu/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d +else: + _import_structure = { + 'free_lunch_utils': + ['register_free_upblock2d', 'register_free_crossattn_upblock2d'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/multi_modal/freeu/free_lunch_utils.py b/modelscope/models/multi_modal/freeu/free_lunch_utils.py new file mode 100644 index 00000000..eb5d191f --- /dev/null +++ b/modelscope/models/multi_modal/freeu/free_lunch_utils.py @@ -0,0 +1,331 @@ +# ------------------------------------------------------------------------ +# Modified from https://github.com/ChenyangSi/FreeU/blob/main/demo/free_lunch_utils.py +# Copyright (c) 2023 TencentARC. All Rights Reserved. +# ------------------------------------------------------------------------ + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.fft as fft +from diffusers.utils import is_torch_version + + +def isinstance_str(x: object, cls_name: str): + """ + Checks whether x has any class *named* cls_name in its ancestry. + Doesn't require access to the class's implementation. + + Useful for patching! 
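+
+    Example (illustrative; mirrors how the register_* helpers below use it):
+        >>> isinstance_str(upsample_block, 'UpBlock2D')  # True if any class in the MRO is named 'UpBlock2D'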
+ """ + + for _cls in x.__class__.__mro__: + if _cls.__name__ == cls_name: + return True + + return False + + +def Fourier_filter(x, threshold, scale): + dtype = x.dtype + x = x.type(torch.float32) + # FFT + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W)).cuda() + + crow, ccol = H // 2, W // 2 + mask[..., crow - threshold:crow + threshold, + ccol - threshold:ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + x_filtered = x_filtered.type(dtype) + return x_filtered + + +def register_upblock2d(model): + + def up_forward(self): + + def forward(hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version('>=', '1.11.0'): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, 'UpBlock2D'): + upsample_block.forward = up_forward(upsample_block) + + +def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + + def up_forward(self): + + def forward(hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # --------------- FreeU code ----------------------- + # Only operate on the first two stages + if hidden_states.shape[1] == 1280: + hidden_states[:, :640] = hidden_states[:, :640] * self.b1 + res_hidden_states = Fourier_filter( + res_hidden_states, threshold=1, scale=self.s1) + if hidden_states.shape[1] == 640: + hidden_states[:, :320] = hidden_states[:, :320] * self.b2 + res_hidden_states = Fourier_filter( + res_hidden_states, threshold=1, scale=self.s2) + # --------------------------------------------------------- + + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version('>=', '1.11.0'): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return 
hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, 'UpBlock2D'): + upsample_block.forward = up_forward(upsample_block) + setattr(upsample_block, 'b1', b1) + setattr(upsample_block, 'b2', b2) + setattr(upsample_block, 's1', s1) + setattr(upsample_block, 's2', s2) + + +def register_crossattn_upblock2d(model): + + def up_forward(self): + + def forward( + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = { + 'use_reentrant': False + } if is_torch_version('>=', '1.11.0') else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, 'CrossAttnUpBlock2D'): + upsample_block.forward = up_forward(upsample_block) + + +def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + + def up_forward(self): + + def forward( + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # --------------- FreeU code ----------------------- + # Only operate on the first two stages + if hidden_states.shape[1] == 1280: + hidden_states[:, :640] = hidden_states[:, :640] * self.b1 + res_hidden_states = Fourier_filter( + 
res_hidden_states, threshold=1, scale=self.s1) + if hidden_states.shape[1] == 640: + hidden_states[:, :320] = hidden_states[:, :320] * self.b2 + res_hidden_states = Fourier_filter( + res_hidden_states, threshold=1, scale=self.s2) + # --------------------------------------------------------- + + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = { + 'use_reentrant': False + } if is_torch_version('>=', '1.11.0') else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, 'CrossAttnUpBlock2D'): + upsample_block.forward = up_forward(upsample_block) + setattr(upsample_block, 'b1', b1) + setattr(upsample_block, 'b2', b2) + setattr(upsample_block, 's1', s1) + setattr(upsample_block, 's2', s2) diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index b5316684..1faa261e 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -26,6 +26,7 @@ if TYPE_CHECKING: from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline from .video_question_answering_pipeline import VideoQuestionAnsweringPipeline from .videocomposer_pipeline import VideoComposerPipeline + from .text_to_image_freeu_pipeline import FreeUTextToImagePipeline else: _import_structure = { 'image_captioning_pipeline': ['ImageCaptioningPipeline'], @@ -53,7 +54,8 @@ else: ['SOONetVideoTemporalGroundingPipeline'], 'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'], 'multimodal_dialogue_pipeline': ['MultimodalDialoguePipeline'], - 'videocomposer_pipeline': ['VideoComposerPipeline'] + 'videocomposer_pipeline': ['VideoComposerPipeline'], + 'text_to_image_freeu_pipeline': ['FreeUTextToImagePipeline'] } import sys diff --git a/modelscope/pipelines/multi_modal/text_to_image_freeu_pipeline.py b/modelscope/pipelines/multi_modal/text_to_image_freeu_pipeline.py new file mode 100644 index 00000000..9300554c --- /dev/null +++ b/modelscope/pipelines/multi_modal/text_to_image_freeu_pipeline.py @@ -0,0 +1,138 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
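+#
+# FreeU text-to-image pipeline: wraps a base Stable Diffusion text-to-image pipeline
+# (SD 1.5 by default) and patches its UNet up-blocks with register_free_upblock2d /
+# register_free_crossattn_upblock2d, so that backbone features are scaled by b1/b2 and
+# the skip (residual) features are Fourier-filtered with scales s1/s2 at inference time.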
+import os.path +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal.freeu import ( + register_free_crossattn_upblock2d, register_free_upblock2d) +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['FreeUTextToImagePipeline'] + + +@PIPELINES.register_module( + Tasks.text_to_image_synthesis, + module_name=Pipelines.freeu_stable_diffusion_text2image) +class FreeUTextToImagePipeline(Pipeline): + + def __init__(self, model=str, preprocessor=None, **kwargs): + """ FreeU Text to Image Pipeline. + + Examples: + + >>> import cv2 + >>> from modelscope.pipelines import pipeline + >>> from modelscope.utils.constant import Tasks + + >>> prompt = "a photo of a running corgi" # prompt + >>> output_image_path = './result.png' + >>> inputs = {'prompt': prompt} + >>> + >>> pipe = pipeline( + >>> Tasks.text_to_image_synthesis, + >>> model='damo/multi-modal_freeu_stable_diffusion', + >>> base_model='AI-ModelScope/stable-diffusion-v1-5', + >>> ) + >>> + >>> output = pipe(inputs)['output_imgs'] + >>> cv2.imwrite(output_image_path, output) + >>> print('pipeline: the output image path is {}'.format(output_image_path)) + """ + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + torch_dtype = kwargs.get('torch_dtype', torch.float32) + self._device = getattr( + kwargs, 'device', + torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + base_model = kwargs.get( + 'base_model', 'AI-ModelScope/stable-diffusion-v1-5') # default 1.5 + self.freeu_params = kwargs.get('freeu_params', { + 'b1': 1.5, + 'b2': 1.6, + 's1': 0.9, + 's2': 0.2 + }) # default + + logger.info('load freeu stable diffusion text to image pipeline done') + self.pipeline = pipeline( + task=Tasks.text_to_image_synthesis, + model=base_model, + torch_dtype=torch_dtype, + device=self._device).pipeline + + def preprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + return inputs + + def forward(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """ + Inputs Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. 
+ negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + """ + if not isinstance(inputs, dict): + raise ValueError( + f'Expected the input to be a dictionary, but got {type(inputs)}' + ) + # -------- freeu block registration + register_free_upblock2d(self.pipeline, **self.freeu_params) + register_free_crossattn_upblock2d(self.pipeline, **self.freeu_params) + # -------- freeu block registration + + output = self.pipeline( + prompt=inputs.get('prompt'), + height=inputs.get('height'), + width=inputs.get('width'), + num_inference_steps=inputs.get('num_inference_steps', 50), + guidance_scale=inputs.get('guidance_scale', 7.5), + negative_prompt=inputs.get('negative_prompt'), + num_images_per_prompt=inputs.get('num_images_per_prompt', 1), + eta=inputs.get('eta', 0.0), + generator=inputs.get('generator'), + latents=inputs.get('latents'), + ).images[0] + + return {'output_tensor': output} + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + output_img = np.array(inputs['output_tensor']) + return {OutputKeys.OUTPUT_IMGS: output_img[:, :, ::-1]} diff --git a/tests/pipelines/test_text_to_image_freeu.py b/tests/pipelines/test_text_to_image_freeu.py new file mode 100644 index 00000000..7aebe318 --- /dev/null +++ b/tests/pipelines/test_text_to_image_freeu.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
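+#
+# Smoke tests for FreeUTextToImagePipeline: FreeU applied on top of a Stable Diffusion
+# base model (SD 2.1 here), exercised both from a downloaded model snapshot and through
+# the pipeline() factory.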
+ +import unittest + +import cv2 + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.multi_modal import FreeUTextToImagePipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class ImageEditingTest(unittest.TestCase): + + def setUp(self) -> None: + self.task = Tasks.text_to_image_synthesis + self.model_id = 'damo/multi-modal_freeu_stable_diffusion' + prompt = 'a photo of a running corgi' # prompt + self.inputs = {'prompt': prompt} + self.output_image_path = './result.png' + self.base_model = 'AI-ModelScope/stable-diffusion-v2-1' + self.freeu_params = { + 'b1': 1.4, + 'b2': 1.6, + 's1': 0.9, + 's2': 0.2 + } # for SD2.1 + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + pipeline = FreeUTextToImagePipeline(cache_path) + pipeline.group_key = self.task + synthesized_img = pipeline( + input=self.inputs)[OutputKeys.OUTPUT_IMGS] # BGR + cv2.imwrite(self.output_image_path, synthesized_img) + print('FreeU pipeline: the synthesized image path is {}'.format( + self.output_image_path)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.text_to_image_synthesis, + model=self.model_id, + base_model=self.base_model, + freeu_params=self.freeu_params) + synthesized_img = pipeline_ins( + self.inputs)[OutputKeys.OUTPUT_IMGS] # BGR + cv2.imwrite(self.output_image_path, synthesized_img) + print('FreeU pipeline: the synthesized image path is {}'.format( + self.output_image_path)) + + +if __name__ == '__main__': + unittest.main() From fb7328f4ec34cf5f4c6478f11846b740f24e1e1d Mon Sep 17 00:00:00 2001 From: "zhangyanzhao.zyz" Date: Fri, 20 Oct 2023 19:56:01 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E6=9B=B4=E6=96=B0sentence=20embedding=20mo?= =?UTF-8?q?del=EF=BC=8C=E6=94=AF=E6=8C=81gte=EF=BC=8Cbloom=20sentence=20em?= =?UTF-8?q?bedding?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14375781 * fix linter * bloom embedding --- .../models/nlp/bert/sentence_embedding.py | 15 +- modelscope/models/nlp/bloom/__init__.py | 2 + .../models/nlp/bloom/sentence_embedding.py | 165 ++++++++++++++++++ .../nlp/sentence_embedding_preprocessor.py | 103 ++++++++++- tests/pipelines/test_sentence_embedding.py | 9 + 5 files changed, 286 insertions(+), 8 deletions(-) create mode 100644 modelscope/models/nlp/bloom/sentence_embedding.py diff --git a/modelscope/models/nlp/bert/sentence_embedding.py b/modelscope/models/nlp/bert/sentence_embedding.py index 92a9da50..b7df5ef9 100644 --- a/modelscope/models/nlp/bert/sentence_embedding.py +++ b/modelscope/models/nlp/bert/sentence_embedding.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import torch +import torch.nn.functional as F from torch import nn from modelscope.metainfo import Models @@ -61,8 +62,9 @@ class BertForSentenceEmbedding(BertPreTrainedModel): def __init__(self, config, **kwargs): super().__init__(config) self.config = config - self.pooler_type = kwargs.get('pooler_type', 'cls') + self.pooler_type = kwargs.get('emb_pooler_type', 'cls') self.pooler = Pooler(self.pooler_type) + self.normalize = kwargs.get('normalize', False) setattr(self, self.base_model_prefix, BertModel(config, add_pooling_layer=False)) @@ -128,6 +130,8 @@ class BertForSentenceEmbedding(BertPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=return_dict) outputs = self.pooler(outputs, attention_mask) + if self.normalize: + outputs = F.normalize(outputs, p=2, dim=-1) return outputs @classmethod @@ -142,8 +146,11 @@ class BertForSentenceEmbedding(BertPreTrainedModel): The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained """ model_dir = kwargs.get('model_dir') - model = super( - Model, - cls).from_pretrained(pretrained_model_name_or_path=model_dir) + model_kwargs = { + 'emb_pooler_type': kwargs.get('emb_pooler_type', 'cls'), + 'normalize': kwargs.get('normalize', False) + } + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) model.model_dir = model_dir return model diff --git a/modelscope/models/nlp/bloom/__init__.py b/modelscope/models/nlp/bloom/__init__.py index b0f04af7..24d7202d 100644 --- a/modelscope/models/nlp/bloom/__init__.py +++ b/modelscope/models/nlp/bloom/__init__.py @@ -6,10 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .backbone import BloomModel from .text_generation import BloomForTextGeneration + from .sentence_embedding import BloomForSentenceEmbedding else: _import_structure = { 'backbone': ['BloomModel'], 'text_generation': ['BloomForTextGeneration'], + 'sentence_embedding': ['BloomForSentenceEmbedding'] } import sys sys.modules[__name__] = LazyImportModule( diff --git a/modelscope/models/nlp/bloom/sentence_embedding.py b/modelscope/models/nlp/bloom/sentence_embedding.py new file mode 100644 index 00000000..ec35db38 --- /dev/null +++ b/modelscope/models/nlp/bloom/sentence_embedding.py @@ -0,0 +1,165 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +from transformers import BloomConfig +from transformers import BloomModel as BloomModelTransform + +from modelscope.metainfo import Models +from modelscope.models import MODELS, TorchModel +from modelscope.outputs import SentencEmbeddingModelOutput +from modelscope.utils.constant import Tasks + + +class DecoderPooler(torch.nn.Module): + """ + Parameter-free poolers to get the sentence embedding + 'last': the last token state. + 'weighted_mean': position weighted average of all token states. + """ + + def __init__(self, pooler_type): + super().__init__() + self.pooler_type = pooler_type + assert self.pooler_type in [ + 'last', 'weighted_mean' + ], 'unrecognized pooling type %s' % self.pooler_type + + def forward(self, outputs, attention_mask): + last_hidden = outputs.last_hidden_state + + if self.pooler_type in ['last']: + n, l, h = last_hidden.shape + + # Get shape [n] indices of the last token (i.e. 
the last token for each batch item) + # Any sequence where min == 1, we use the entire sequence lenth since argmin = 0 + values, indices = torch.min(attention_mask, 1, keepdim=False) + gather_indices = torch.where(values == 0, indices, + l) - 1 # Shape [n] + + # There are empty sequences, where the index would become -1 which will crash + gather_indices = torch.clamp(gather_indices, min=0) + + # Turn indices from shape [n] --> [n, 1, h] + gather_indices = gather_indices.unsqueeze(1).unsqueeze(1).expand( + n, 1, h) + + # Gather along the 1st dim (l) (n, l, h -> n, h) + pooled_output = torch.gather(last_hidden, 1, + gather_indices).squeeze(dim=1) + + elif self.pooler_type == 'weighted_mean': + input_mask_expanded = attention_mask.unsqueeze(-1).expand( + last_hidden.size()).float() + # last_hidden shape: bs, seq, hidden_dim + weights = ( + torch.arange(start=1, end=last_hidden.shape[1] + + 1).unsqueeze(0).unsqueeze(-1).expand( + last_hidden.size()).float().to( + last_hidden.device)) + assert weights.shape == last_hidden.shape == input_mask_expanded.shape + input_mask_expanded = input_mask_expanded * weights + + sum_embeddings = torch.sum(last_hidden * input_mask_expanded, 1) + sum_mask = input_mask_expanded.sum(1) + sum_mask = torch.clamp(sum_mask, min=1e-9) + pooled_output = sum_embeddings / sum_mask + + else: + raise NotImplementedError + + return pooled_output + + +@MODELS.register_module( + group_key=Tasks.sentence_embedding, module_name=Models.bloom) +class BloomForSentenceEmbedding(BloomModelTransform, TorchModel): + r""" + This model represent a text to a dense vector by the last token state or weighted mean of all token states. + See `Language Models are Universal Embedders + `_ for details. + """ + + def __init__(self, config, **kwargs): + super().__init__(config) + self.config = config + self.pooler_type = kwargs.get('emb_pooler_type', 'weighted_mean') + self.pooler = DecoderPooler(self.pooler_type) + self.normalize = kwargs.get('normalize', False) + setattr(self, self.base_model_prefix, BloomModelTransform(config)) + + def forward(self, query=None, docs=None, labels=None): + r""" + Args: + query (:obj: `dict`): Dict of pretrained models's input for the query sequence. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + docs (:obj: `dict`): Dict of pretrained models's input for the query sequence. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. 
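+            labels (:obj: `torch.LongTensor`, *optional*): Index of the matching doc for each query,
+                used for the contrastive cross-entropy loss in training mode. If not provided, the
+                i-th query is assumed to match the i-th doc (or doc group).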
+ Returns: + Returns `modelscope.outputs.SentencEmbeddingModelOutput + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_udever_bloom_560m') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_udever_bloom_560m') + >>> inputs = preprocessor({'source_sentence': ['This is a test']}) + >>> outputs = model(**inputs) + >>> print(outputs) + """ + query_embeddings, doc_embeddings = None, None + if query is not None: + query_embeddings = self.encode(**query) + if docs is not None: + doc_embeddings = self.encode(**docs) + outputs = SentencEmbeddingModelOutput( + query_embeddings=query_embeddings, doc_embeddings=doc_embeddings) + if query_embeddings is None or doc_embeddings is None: + return outputs + if self.base_model.training: + loss_fct = torch.nn.CrossEntropyLoss() + scores = torch.matmul(query_embeddings, doc_embeddings.T) + if labels is None: + labels = torch.arange( + scores.size(0), device=scores.device, dtype=torch.long) + labels = labels * ( + doc_embeddings.size(0) // query_embeddings.size(0)) + loss = loss_fct(scores, labels) + outputs.loss = loss + return outputs + + def encode( + self, + input_ids=None, + attention_mask=None, + ): + outputs = self.base_model.forward( + input_ids, attention_mask=attention_mask) + embeddings = self.pooler(outputs, attention_mask) + if self.normalize: + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1) + return embeddings + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + model_dir = kwargs.get('model_dir') + model_kwargs = { + 'emb_pooler_type': kwargs.get('emb_pooler_type', 'weighted_mean'), + 'normalize': kwargs.get('normalize', False) + } + if model_dir is None: + config = BloomConfig(**kwargs) + model = cls(config) + else: + model = super(BloomModelTransform, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + model.model_dir = model_dir + return model diff --git a/modelscope/preprocessors/nlp/sentence_embedding_preprocessor.py b/modelscope/preprocessors/nlp/sentence_embedding_preprocessor.py index b03268c6..f1ca6685 100644 --- a/modelscope/preprocessors/nlp/sentence_embedding_preprocessor.py +++ b/modelscope/preprocessors/nlp/sentence_embedding_preprocessor.py @@ -1,14 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Any, Dict +from typing import Any, Dict, Optional + +import torch from modelscope.metainfo import Preprocessors from modelscope.preprocessors import Preprocessor from modelscope.preprocessors.builder import PREPROCESSORS from modelscope.utils.constant import Fields, ModeKeys from modelscope.utils.hub import get_model_type +from modelscope.utils.logger import get_logger from .transformers_tokenizer import NLPTokenizer +logger = get_logger() + @PREPROCESSORS.register_module( Fields.nlp, module_name=Preprocessors.sentence_embedding) @@ -46,9 +51,32 @@ class SentenceEmbeddingTransformersPreprocessor(Preprocessor): self.max_length = max_length if model_dir is not None: model_type = get_model_type(model_dir) + # we could add `boq/bod` token/prompt and `eoq/eod` token if they exist when tokenizing. 
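+            # `boq`/`bod` are optional begin-of-query/begin-of-document prompts prepended to the
+            # raw text; `eoq`/`eod` are optional end-of-query/end-of-document tokens appended after
+            # tokenization (see `tokenize` and `add_eos*` below). All default to None.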
+ for k in ('boq', 'eoq', 'bod', 'eod'): + setattr(self, k, kwargs.pop(k, None)) self.nlp_tokenizer = NLPTokenizer( model_dir, model_type, use_fast=use_fast, tokenize_kwargs=kwargs) super().__init__(mode=mode) + tokenizer = self.nlp_tokenizer.tokenizer + # For tokenizers like bloom + if tokenizer.padding_side != 'right': + # weighted mean pooling need pad right + logger.warning( + f'Change tokenizer.padding_side from {tokenizer.padding_side} to right' + ) + tokenizer.padding_side = 'right' + # For decoder-only tokenizers + if tokenizer.pad_token is None: + logger.warning( + f'Set tokenizer.pad_token as eos_token {tokenizer.eos_token}') + tokenizer.pad_token = tokenizer.eos_token + # Currently eos is single token, we can extend to prompt later. + for k in ('eoq', 'eod'): + v = getattr(self, k, None) + if v is not None: + v = tokenizer.convert_tokens_to_ids(v) + setattr(self, k + '_id', v) + self.pad_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) def __call__(self, data: Dict, @@ -81,13 +109,80 @@ class SentenceEmbeddingTransformersPreprocessor(Preprocessor): if 'return_tensors' not in kwargs: kwargs[ 'return_tensors'] = 'pt' if self.mode == ModeKeys.INFERENCE else None - query_inputs = self.nlp_tokenizer( - source_sentences, padding=padding, truncation=truncation, **kwargs) + query_inputs = self.tokenize( + source_sentences, + is_query=True, + padding=padding, + truncation=truncation, + **kwargs) tokenized_inputs = {'query': query_inputs, 'docs': None} if compare_sentences is not None and len(compare_sentences) > 0: - tokenized_inputs['docs'] = self.nlp_tokenizer( + tokenized_inputs['docs'] = self.tokenize( compare_sentences, + is_query=kwargs.get('symmetric', False), padding=padding, truncation=truncation, **kwargs) return tokenized_inputs + + def tokenize(self, texts, is_query=True, return_tensors=None, **kwargs): + """Tokenize raw texts, add `boq/bod` token/prompt and `eoq/eod` token if they exist. + + Args: + `texts` List[str]: texts to tokenize, + Example: + ["how long it take to get a master's degree"] + `is_query` bool: whether the input text(s) is query. + `return_tensors` str: the `return_tensors` argument to tokenizer. 
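+            `**kwargs`: extra tokenizer arguments (e.g. `padding`, `truncation`) forwarded to
+                the underlying `nlp_tokenizer`.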
+ Returns: + Dict[str, Any]: the preprocessed data + """ + if is_query: + bos, eos_id = self.boq, self.eoq_id + else: + bos, eos_id = self.bod, self.eod_id + if bos is not None: + # bos can be prompt + texts = [bos + t for t in texts] + encoding = self.nlp_tokenizer( + texts, return_tensors=return_tensors, **kwargs) + if eos_id is not None: + if return_tensors == 'pt': + self.add_eos_pt(encoding, eos_id) + else: + self.add_eos(encoding, eos_id) + return encoding + + def add_eos_pt(self, encoding: Dict[str, torch.Tensor], eos: int): + """Add `eos` token id to the end of each sequence.""" + input_ids, attn_mask = encoding['input_ids'], encoding[ + 'attention_mask'] + batch = torch.arange(input_ids.size(0)) + length = attn_mask.sum(-1) + + if input_ids.size(1) < self.max_length: + ones = input_ids.new_ones(input_ids.size(0), 1) + attn_mask = torch.cat((ones, attn_mask), dim=1) + padding = ones * self.pad_id + input_ids = torch.cat((input_ids, padding), dim=1) + eos_index = length + else: + eos_index = torch.clamp(length, max=self.max_length - 1) + attn_mask[batch, eos_index] = 1 + input_ids[batch, eos_index] = eos + encoding['input_ids'], encoding[ + 'attention_mask'] = input_ids, attn_mask + return + + def add_eos(self, encoding: Dict[str, list], eos: int): + """Add `eos` token id to the end of each sequence.""" + for ids, mask in zip(encoding['input_ids'], + encoding['attention_mask']): + if len(mask) < self.max_length: + ids.append(eos) + mask.append(1) + else: + last = min(sum(mask), self.max_length - 1) + ids[last] = eos + mask[last] = 1 + return diff --git a/tests/pipelines/test_sentence_embedding.py b/tests/pipelines/test_sentence_embedding.py index 13260132..a6dd89ec 100644 --- a/tests/pipelines/test_sentence_embedding.py +++ b/tests/pipelines/test_sentence_embedding.py @@ -21,6 +21,7 @@ class SentenceEmbeddingTest(unittest.TestCase): medical_tiny_model_id = 'damo/nlp_corom_sentence-embedding_chinese-tiny-medical' general_base_model_id = 'damo/nlp_corom_sentence-embedding_chinese-base' general_tiny_model_id = 'damo/nlp_corom_sentence-embedding_chinese-tiny' + bloom_model_id = 'damo/udever-bloom-7b1' inputs = { 'source_sentence': ["how long it take to get a master's degree"], @@ -154,6 +155,14 @@ class SentenceEmbeddingTest(unittest.TestCase): print() print(f'pipeline2: {pipeline2(input=self.medical_inputs1)}') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_bloom_model_from_modelhub(self): + model = Model.from_pretrained(self.bloom_model_id) + tokenizer = SentenceEmbeddingTransformersPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.sentence_embedding, model=model, preprocessor=tokenizer) + print(pipeline_ins(input=self.inputs)) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id)