From 72f8c43ccaa6bc19c5e27d32953f46687ab268a2 Mon Sep 17 00:00:00 2001 From: "siyang.ssy" Date: Tue, 26 Jul 2022 22:48:07 +0800 Subject: [PATCH] [to #42322933]add video multi-model feature Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9458640 --- .gitignore | 1 + .../videos/multi_modal_test_video_9770.mp4 | 3 + modelscope/metainfo.py | 2 + modelscope/models/__init__.py | 10 +- modelscope/models/multi_modal/__init__.py | 2 + .../models/multi_modal/imagen/structbert.py | 2 +- modelscope/models/multi_modal/mmr/__init__.py | 0 .../multi_modal/mmr/dataloaders/__init__.py | 0 .../mmr/dataloaders/rawvideo_util.py | 114 ++++ .../models/multi_modal/mmr/models/__init__.py | 0 .../clip_for_multi_model_video_embedding.py | 218 ++++++++ .../mmr/models/dynamic_inverted_softmax.py | 42 ++ .../models/multi_modal/mmr/models/modeling.py | 508 +++++++++++++++++ .../multi_modal/mmr/models/module_clip.py | 526 ++++++++++++++++++ .../multi_modal/mmr/models/module_cross.py | 100 ++++ .../mmr/models/tokenization_clip.py | 158 ++++++ .../multi_modal/mmr/models/until_module.py | 120 ++++ .../nlp/bert_for_sequence_classification.py | 2 +- .../models/nlp/palm_for_text_generation.py | 3 +- .../nlp/space_for_dialog_intent_prediction.py | 3 +- .../nlp/space_for_dialog_state_tracking.py | 2 +- modelscope/msdatasets/ms_dataset.py | 2 +- modelscope/pipelines/builder.py | 3 + modelscope/pipelines/multi_modal/__init__.py | 5 +- .../video_multi_modal_embedding_pipeline.py | 42 ++ modelscope/pipelines/nlp/__init__.py | 2 +- .../dialog_state_tracking_preprocessor.py | 2 +- .../nlp/sequence_classification_trainer.py | 3 +- modelscope/utils/constant.py | 2 +- modelscope/utils/hub.py | 2 +- setup.py | 2 +- .../test_video_multi_modal_embedding.py | 45 ++ 32 files changed, 1908 insertions(+), 18 deletions(-) create mode 100644 data/test/videos/multi_modal_test_video_9770.mp4 create mode 100644 modelscope/models/multi_modal/mmr/__init__.py create mode 100644 modelscope/models/multi_modal/mmr/dataloaders/__init__.py create mode 100644 modelscope/models/multi_modal/mmr/dataloaders/rawvideo_util.py create mode 100644 modelscope/models/multi_modal/mmr/models/__init__.py create mode 100644 modelscope/models/multi_modal/mmr/models/clip_for_multi_model_video_embedding.py create mode 100644 modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py create mode 100644 modelscope/models/multi_modal/mmr/models/modeling.py create mode 100644 modelscope/models/multi_modal/mmr/models/module_clip.py create mode 100644 modelscope/models/multi_modal/mmr/models/module_cross.py create mode 100644 modelscope/models/multi_modal/mmr/models/tokenization_clip.py create mode 100644 modelscope/models/multi_modal/mmr/models/until_module.py create mode 100644 modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py create mode 100644 tests/pipelines/test_video_multi_modal_embedding.py diff --git a/.gitignore b/.gitignore index 05929ea9..8a0db7fa 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,4 @@ replace.sh # Pytorch *.pth +*.pt diff --git a/data/test/videos/multi_modal_test_video_9770.mp4 b/data/test/videos/multi_modal_test_video_9770.mp4 new file mode 100644 index 00000000..45245b52 --- /dev/null +++ b/data/test/videos/multi_modal_test_video_9770.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33e21c16d5388684b61d7251b9d4e418f8146c3ba3fa400ebd8d913058687cfc +size 431888 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index d4bb64aa..3a38e8d3 100644 --- 
a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -34,6 +34,7 @@ class Models(object): gemm = 'gemm-generative-multi-modal' mplug = 'mplug' imagen = 'imagen-text-to-image-synthesis' + video_clip = 'video-clip-multi-modal-embedding' class TaskModels(object): @@ -99,6 +100,7 @@ class Pipelines(object): generative_multi_modal_embedding = 'generative-multi-modal-embedding' visual_question_answering = 'visual-question-answering' text_to_image_synthesis = 'text-to-image-synthesis' + video_multi_modal_embedding = 'video-multi-modal-embedding' class Trainers(object): diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py index bc24eef6..95af2047 100644 --- a/modelscope/models/__init__.py +++ b/modelscope/models/__init__.py @@ -5,10 +5,10 @@ from .base import Model from .builder import MODELS, build_model try: - from .audio.asr import GenericAutomaticSpeechRecognition - from .audio.tts import SambertHifigan - from .audio.kws import GenericKeyWordSpotting from .audio.ans.frcrn import FRCRNModel + from .audio.asr import GenericAutomaticSpeechRecognition + from .audio.kws import GenericKeyWordSpotting + from .audio.tts import SambertHifigan except ModuleNotFoundError as e: print(AUDIO_IMPORT_ERROR.format(e)) @@ -29,8 +29,8 @@ try: SbertForZeroShotClassification, SpaceForDialogIntent, SpaceForDialogModeling, SpaceForDialogStateTracking, StructBertForMaskedLM, VecoForMaskedLM) - from .nlp.heads import (SequenceClassificationHead) - from .nlp.backbones import (SbertModel) + from .nlp.backbones import SbertModel + from .nlp.heads import SequenceClassificationHead except ModuleNotFoundError as e: if str(e) == "No module named 'pytorch'": pass diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index 89db0290..14c791b0 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -1,6 +1,8 @@ from .clip.clip_model import CLIPForMultiModalEmbedding from .gemm.gemm_model import GEMMForMultiModalEmbedding from .imagen.imagen_model import ImagenForTextToImageSynthesis +from .mmr.models.clip_for_multi_model_video_embedding import \ + VideoCLIPForMultiModalEmbedding from .mplug_for_visual_question_answering import \ MPlugForVisualQuestionAnswering from .ofa_for_image_captioning_model import OfaForImageCaptioning diff --git a/modelscope/models/multi_modal/imagen/structbert.py b/modelscope/models/multi_modal/imagen/structbert.py index 219e642f..d5d678ed 100644 --- a/modelscope/models/multi_modal/imagen/structbert.py +++ b/modelscope/models/multi_modal/imagen/structbert.py @@ -784,7 +784,7 @@ class BertModel(nn.Module): elif config.transformer_type.lower() == 'act': self.encoder = BERTEncoderACT(config) elif config.transformer_type.lower() == 'textnas': - from textnas_final import op_dict, input_dict, skip_dict + from textnas_final import input_dict, op_dict, skip_dict self.encoder = TextNASEncoder(config, op_dict, input_dict, skip_dict) else: diff --git a/modelscope/models/multi_modal/mmr/__init__.py b/modelscope/models/multi_modal/mmr/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/mmr/dataloaders/__init__.py b/modelscope/models/multi_modal/mmr/dataloaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/mmr/dataloaders/rawvideo_util.py b/modelscope/models/multi_modal/mmr/dataloaders/rawvideo_util.py new file mode 100644 index 00000000..eab1189f --- /dev/null +++ 
b/modelscope/models/multi_modal/mmr/dataloaders/rawvideo_util.py @@ -0,0 +1,114 @@ +import cv2 +import numpy as np +import torch as th +from PIL import Image +from torchvision.transforms import (CenterCrop, Compose, InterpolationMode, + Normalize, Resize, ToTensor) + +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class RawVideoExtractorCV2(): + + def __init__( + self, + centercrop=False, + size=224, + frame_rate=-1, + ): + self.centercrop = centercrop + self.size = size + self.framerate = frame_rate + self.transform = self._transform(self.size) + + def _transform(self, n_px): + return Compose([ + Resize(n_px, interpolation=InterpolationMode.BICUBIC), + CenterCrop(n_px), + lambda image: image.convert('RGB'), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + + def video_to_tensor(self, + video_file, + preprocess, + sample_fp=0, + start_time=None, + end_time=None): + if start_time is not None or end_time is not None: + assert isinstance(start_time, int) and isinstance(end_time, int) \ + and start_time > -1 and end_time > start_time + assert sample_fp > -1 + + # Samples a frame sample_fp X frames. + cap = cv2.VideoCapture(video_file) + frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = int(cap.get(cv2.CAP_PROP_FPS)) + + if fps == 0: + logger.info(f'{video_file} with fps 0!!!') + total_duration = (frameCount + fps - 1) // fps + start_sec, end_sec = 0, total_duration + + if start_time is not None: + start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration + cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) + + interval = 1 + if sample_fp > 0: + interval = fps // sample_fp + else: + sample_fp = fps + if interval == 0: + interval = 1 + + inds = [ind for ind in np.arange(0, fps, interval)] + assert len(inds) >= sample_fp + inds = inds[:sample_fp] + + ret = True + images = [] + + for sec in np.arange(start_sec, end_sec + 1): + if not ret: + break + sec_base = int(sec * fps) + for ind in inds: + cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) + ret, frame = cap.read() + if not ret: + break + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + images.append( + preprocess(Image.fromarray(frame_rgb).convert('RGB'))) + + cap.release() + + if len(images) > 0: + video_data = th.tensor(np.stack(images)) + else: + video_data = th.zeros(1) + return {'video': video_data} + + def get_video_data(self, video_path, start_time=None, end_time=None): + image_input = self.video_to_tensor( + video_path, + self.transform, + sample_fp=self.framerate, + start_time=start_time, + end_time=end_time) + return image_input + + def process_raw_data(self, raw_video_data): + tensor_size = raw_video_data.size() + tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], + tensor_size[-1]) + return tensor + + +# An ordinary video frame extractor based CV2 +RawVideoExtractor = RawVideoExtractorCV2 diff --git a/modelscope/models/multi_modal/mmr/models/__init__.py b/modelscope/models/multi_modal/mmr/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/mmr/models/clip_for_multi_model_video_embedding.py b/modelscope/models/multi_modal/mmr/models/clip_for_multi_model_video_embedding.py new file mode 100644 index 00000000..426581db --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/clip_for_multi_model_video_embedding.py @@ -0,0 +1,218 @@ +import os +import random +from os.path import exists +from typing import Any, Dict + 
+import json
+import numpy as np
+import torch
+from PIL import Image
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Model
+from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.mmr.dataloaders.rawvideo_util import \
+    RawVideoExtractor
+from modelscope.models.multi_modal.mmr.models.modeling import CLIP4Clip
+from modelscope.models.multi_modal.mmr.models.tokenization_clip import \
+    SimpleTokenizer as ClipTokenizer
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@MODELS.register_module(
+    Tasks.video_multi_modal_embedding, module_name=Models.video_clip)
+class VideoCLIPForMultiModalEmbedding(Model):
+
+    def __init__(self, model_dir, device_id=-1):
+        super().__init__(model_dir=model_dir, device_id=device_id)
+        # model config parameters
+        with open(f'{model_dir}/{ModelFile.CONFIGURATION}', 'r') as json_file:
+            model_config = json.load(json_file)
+        model_config = model_config['paras']
+        model_config['model_dir'] = model_dir
+        self.SPECIAL_TOKEN = {
+            'CLS_TOKEN': '<|startoftext|>',
+            'SEP_TOKEN': '<|endoftext|>',
+            'MASK_TOKEN': '[MASK]',
+            'UNK_TOKEN': '[UNK]',
+            'PAD_TOKEN': '[PAD]'
+        }
+        self.max_words = model_config['max_words']
+        self.max_frames = model_config['max_frames']
+        self.feature_framerate = model_config['feature_framerate']
+        self.image_resolution = 224
+        self.device = model_config['device']
+        self.init_model = f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}'
+
+        self.tokenizer = ClipTokenizer(model_dir)
+        self.rawVideoExtractor = RawVideoExtractor(
+            frame_rate=self.feature_framerate, size=self.image_resolution)
+        self.local_transform = self.rawVideoExtractor.transform
+
+        self.model = CLIP4Clip(model_config)
+        if hasattr(self.model, 'module'):
+            self.model = self.model.module.to(self.device)
+        else:
+            self.model = self.model.to(self.device)
+        if self.init_model:
+            assert exists(self.init_model)
+            model_state_dict = torch.load(self.init_model, map_location='cpu')
+            self.model.load_state_dict(model_state_dict, strict=False)
+        self.model.to(self.device)
+
+    def _get_text(self, caption, tokenizer, enable_zh=False):
+        if len(caption) == 3:
+            _caption_text, s, e = caption
+        elif len(caption) == 4:
+            _caption_text, s, e, pos = caption
+        else:
+            raise NotImplementedError
+
+        if isinstance(_caption_text, list):
+            caption_text = random.choice(_caption_text)
+        else:
+            caption_text = _caption_text
+        if enable_zh:
+            _token = tokenizer.encode(caption_text)
+            input_ids = _token.ids
+            input_mask = _token.attention_mask
+            segment_ids = _token.type_ids
+        else:
+            words = tokenizer.tokenize(caption_text)
+
+            words = [self.SPECIAL_TOKEN['CLS_TOKEN']] + words
+            total_length_with_CLS = self.max_words - 1
+            if len(words) > total_length_with_CLS:
+                words = words[:total_length_with_CLS]
+            words = words + [self.SPECIAL_TOKEN['SEP_TOKEN']]
+
+            input_ids = tokenizer.convert_tokens_to_ids(words)
+            input_mask = [1] * len(input_ids)
+            segment_ids = [0] * len(input_ids)
+
+        while len(input_ids) < self.max_words:
+            input_ids.append(0)
+            input_mask.append(0)
+            segment_ids.append(0)
+        assert len(input_ids) == self.max_words
+        assert len(input_mask) == self.max_words
+        assert len(segment_ids) == self.max_words
+
+        pairs_text = np.array(input_ids)
+        pairs_mask = np.array(input_mask)
+        pairs_segment = np.array(segment_ids)
+
+        return pairs_text, pairs_mask, pairs_segment, s, e
+
+    def _get_rawvideo_dec(self,
+                          video_path,
+                          rawVideoExtractor,
+                          local_transform,
+                          s=None,
+                          e=None):
+        video_mask = np.zeros(self.max_frames, dtype=np.int64)
+        max_video_length = 0
+
+        # T x 3 x H x W
+        video = np.zeros((self.max_frames, 3, rawVideoExtractor.size,
+                          rawVideoExtractor.size),
+                         dtype=np.float64)
+
+        if s is None:
+            start_time, end_time = None, None
+        else:
+            start_time = int(s)
+            end_time = int(e)
+            start_time = start_time if start_time >= 0. else 0.
+            end_time = end_time if end_time >= 0. else 0.
+            if start_time > end_time:
+                start_time, end_time = end_time, start_time
+            elif start_time == end_time:
+                end_time = end_time + 1
+
+        if exists(video_path):
+            from decord import VideoReader, cpu
+            vreader = VideoReader(video_path, ctx=cpu(0))
+        else:
+            logger.error('non video input, output is wrong!!!')
+            return video, video_mask
+
+        fps = vreader.get_avg_fps()
+        f_start = 0 if start_time is None else int(start_time * fps)
+        f_end = int(
+            min(1000000000 if end_time is None else end_time * fps,
+                len(vreader) - 1))
+        num_frames = f_end - f_start + 1
+        if num_frames > 0:
+            # L x T x 3 x H x W
+            sample_fps = int(self.feature_framerate)
+            t_stride = int(round(float(fps) / sample_fps))
+
+            all_pos = list(range(f_start, f_end + 1, t_stride))
+            if len(all_pos) > self.max_frames:
+                sample_pos = [
+                    all_pos[_] for _ in np.linspace(
+                        0, len(all_pos) - 1, num=self.max_frames, dtype=int)
+                ]
+            else:
+                sample_pos = all_pos
+            patch_images = [
+                Image.fromarray(f)
+                for f in vreader.get_batch(sample_pos).asnumpy()
+            ]
+            patch_images = torch.stack(
+                [local_transform(img) for img in patch_images])
+            slice_len = patch_images.shape[0]
+            max_video_length = max_video_length if max_video_length > slice_len else slice_len
+            if slice_len < 1:
+                pass
+            else:
+                video[:slice_len, ...] = patch_images
+        else:
+            logger.error('video path: {} error.'.format(video_path))
+
+        video_mask[:max_video_length] = [1] * max_video_length
+
+        return video, video_mask
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+
+        from modelscope.outputs import OutputKeys
+        output = {}
+
+        if 'video' in input and input['video'] is not None:
+            video_path = input['video']
+            video, video_mask = self._get_rawvideo_dec(video_path,
+                                                       self.rawVideoExtractor,
+                                                       self.local_transform)
+            video = torch.unsqueeze(
+                torch.from_numpy(video), dim=0).to(self.device)
+            video_mask = torch.unsqueeze(
+                torch.from_numpy(video_mask), dim=0).to(self.device)
+
+        if 'text' in input and input['text'] is not None:
+            caption = input['text']
+            pairs_text, pairs_mask, pairs_segment, s, e = self._get_text(
+                caption, self.tokenizer, enable_zh=False)
+            input_ids = torch.unsqueeze(
+                torch.from_numpy(pairs_text), dim=0).to(self.device)
+            input_mask = torch.unsqueeze(
+                torch.from_numpy(pairs_mask), dim=0).to(self.device)
+            segment_ids = torch.unsqueeze(
+                torch.from_numpy(pairs_segment), dim=0).to(self.device)
+
+        sequence_output, visual_output = self.model.get_sequence_visual_output(
+            input_ids, segment_ids, input_mask, video, video_mask)
+        logger.info('text feature: {}'.format(sequence_output[0][0][0]))
+        logger.info('video feature: {}'.format(visual_output[0][0][0]))
+
+        output[OutputKeys.VIDEO_EMBEDDING] = visual_output
+        output[OutputKeys.TEXT_EMBEDDING] = sequence_output
+        return output
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py b/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py new file mode 100644 index 00000000..572f44bc --- /dev/null +++
b/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py @@ -0,0 +1,42 @@ +import numpy as np + + +def get_retrieved_videos(sims, k): + """ + Returns list of retrieved top k videos based on the sims matrix + Args: + sims: similar matrix. + K: top k number of videos + """ + argm = np.argsort(-sims, axis=1) + topk = argm[:, :k].reshape(-1) + retrieved_videos = np.unique(topk) + return retrieved_videos + + +def get_index_to_normalize(sims, videos): + """ + Returns list of indices to normalize from sims based on videos + Args: + sims: similar matrix. + videos: video array. + """ + argm = np.argsort(-sims, axis=1)[:, 0] + result = np.array(list(map(lambda x: x in videos, argm))) + result = np.nonzero(result) + return result + + +def qb_norm(train_test, test_test, args): + k = args.get('k', 1) + beta = args.get('beta', 20) + retrieved_videos = get_retrieved_videos(train_test, k) + test_test_normalized = test_test + train_test = np.exp(train_test * beta) + test_test = np.exp(test_test * beta) + + normalizing_sum = np.sum(train_test, axis=0) + index_for_normalizing = get_index_to_normalize(test_test, retrieved_videos) + test_test_normalized[index_for_normalizing, :] = \ + np.divide(test_test[index_for_normalizing, :], normalizing_sum) + return test_test_normalized diff --git a/modelscope/models/multi_modal/mmr/models/modeling.py b/modelscope/models/multi_modal/mmr/models/modeling.py new file mode 100644 index 00000000..214e65c7 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/modeling.py @@ -0,0 +1,508 @@ +import os +import platform +from collections import OrderedDict +from types import SimpleNamespace + +import torch +from torch import nn +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + +from modelscope.models.multi_modal.mmr.models.module_clip import ( + _PT_NAME, CLIP, QuickGELU, convert_weights) +from modelscope.models.multi_modal.mmr.models.module_cross import \ + Transformer as TransformerClip +from modelscope.models.multi_modal.mmr.models.until_module import (AllGather, + CrossEn, + LayerNorm) +from modelscope.utils.logger import get_logger + +allgather = AllGather.apply + +logger = get_logger() +__all__ = ['CLIP4Clip'] + + +class CLIP4Clip(nn.Module): + + def __init__(self, config): + super(CLIP4Clip, self).__init__() + + self.config = config + self.loose_type = config['loose_type'] + self.sim_header = config['sim_header'] + if self.sim_header in [ + 'tightTransf', 'tightFc1', 'tightFc2', 'tightFc3', 'tightFc4', + 'tightMean', 'tightFc5' + ]: + assert self.loose_type is False + + backbone = config['pretrained_clip_name'] + + # fix backbone without downlond + model_path = '{}/ViT-B-16.pt'.format(config['model_dir']) + if not os.path.exists(model_path): + logger.info('no model loaded!!!') + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location='cpu').eval() + state_dict = model.state_dict() + except RuntimeError: + state_dict = torch.load(model_path, map_location='cpu') + + vision_width = state_dict['visual.conv1.weight'].shape[0] + vision_layers = len([ + k for k in state_dict.keys() + if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') + ]) + vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] + grid_size = round( + (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) + image_resolution = vision_patch_size * grid_size + + embed_dim = state_dict['text_projection'].shape[1] + context_length = state_dict['positional_embedding'].shape[0] + vocab_size = 
state_dict['token_embedding.weight'].shape[0] + transformer_width = state_dict['ln_final.weight'].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split('.')[2] for k in state_dict + if k.startswith('transformer.resblocks'))) + + cut_top_layer = 0 + self.clip = CLIP( + embed_dim, + image_resolution, + vision_layers - cut_top_layer, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers - cut_top_layer, + linear_patch=config['linear_patch'], + use_gc=config['use_gc']).float() + + if (platform.system() != 'Darwin'): + convert_weights(self.clip) # fp16 + + if backbone in ['ViT-B/32', 'ViT-B/16']: + cross_config = SimpleNamespace(**{ + 'hidden_size': 512, + 'max_position_embeddings': 128, + }) + elif backbone in ['ViT-L/14', 'ViT-B/14-336px']: + cross_config = SimpleNamespace(**{ + 'hidden_size': 768, + 'max_position_embeddings': 128, + }) + else: + raise ValueError + + cross_config.max_position_embeddings = context_length + self.cross_config = cross_config + + self.text_weight_fc = nn.Sequential( + nn.Linear(transformer_width, transformer_width), + nn.ReLU(inplace=True), nn.Linear(transformer_width, 1)) + self.video_weight_fc = nn.Sequential( + nn.Linear(transformer_width, transformer_width), + nn.ReLU(inplace=True), nn.Linear(transformer_width, 1)) + + if self.loose_type is False: + raise NotImplementedError + + if self.sim_header in ['seqLSTM', 'seqTransf', 'tightFc1']: + self.frame_position_embeddings = nn.Embedding( + cross_config.max_position_embeddings, cross_config.hidden_size) + if self.sim_header in ['seqTransf', 'tightFc1']: + self.transformerClip = TransformerClip( + width=transformer_width, + layers=config['cross_num_hidden_layers'], + heads=transformer_heads, + ) + if self.sim_header == 'seqLSTM': + self.lstm_visual = nn.LSTM( + input_size=cross_config.hidden_size, + hidden_size=cross_config.hidden_size, + batch_first=True, + bidirectional=False, + num_layers=1) + + self.loss_fct = CrossEn(config) + + self.apply(self.init_weights) + self.clip.load_state_dict(state_dict, strict=False) + + # ===> Initialization trick [HARD CODE] + if backbone not in _PT_NAME: + raise NotImplementedError + # reload + else: + if config['linear_patch'] == '3d': + raise NotImplementedError + + new_state_dict = OrderedDict() + if self.sim_header == 'tightTransf': + raise NotImplementedError + + if self.sim_header in ['seqLSTM', 'seqTransf', 'seqFc1']: + contain_frame_position = False + for key in state_dict.keys(): + if key.find('frame_position_embeddings') > -1: + contain_frame_position = True + break + if contain_frame_position is False: + for key, val in state_dict.items(): + if key == 'positional_embedding': + new_state_dict[ + 'frame_position_embeddings.weight'] = val.clone() + continue + if self.sim_header in [ + 'seqTransf', 'seqFc1' + ] and key.find('transformer.resblocks') == 0: + num_layer = int(key.split('.')[2]) + # cut from beginning + if num_layer < config['cross_num_hidden_layers']: + new_state_dict[key.replace( + 'transformer.', + 'transformerClip.')] = val.clone() + continue + # <=== End of initialization trick + + self.load_state_dict( + new_state_dict, strict=False + ) # only update new state (seqTransf/seqLSTM/tightTransf) + if self.sim_header == 'tightFc5': + raise ValueError + + def forward(self, + input_ids, + token_type_ids, + attention_mask, + video, + video_mask=None): + input_ids = input_ids.view(-1, input_ids.shape[-1]) + token_type_ids = 
token_type_ids.view(-1, token_type_ids.shape[-1]) + attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) + video_mask = video_mask.view(-1, video_mask.shape[-1]) + + # B x T x 3 x H x W - > (B x T) x 3 x H x W + video = torch.as_tensor(video).float() + if len(video.shape) == 6: # image + b, bs, ts, channel, h, w = video.shape + b = b * bs + else: # video + b, ts, channel, h, w = video.shape + video = video.view(b * ts, channel, h, w) + + sequence_output, visual_output = self.get_sequence_visual_output( + input_ids, + token_type_ids, + attention_mask, + video, + video_mask, + shaped=True) + + if self.training: + loss = 0. + sim_matrix1, sim_matrix2, barlow_loss = self.get_similarity_logits( + sequence_output, + visual_output, + attention_mask, + video_mask, + shaped=True, + loose_type=self.loose_type) + sim_loss = (self.loss_fct(sim_matrix1) + + self.loss_fct(sim_matrix2)) / 2 + loss += sim_loss + barlow_loss * self.config.cdcr_lambda + + return loss + else: + return None + + def get_sequence_output(self, + input_ids, + token_type_ids, + attention_mask, + shaped=False): + if shaped is False: + input_ids = input_ids.view(-1, input_ids.shape[-1]) + token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) + attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) + + bs_pair = input_ids.size(0) + sequence_hidden = self.clip.encode_text( + input_ids, return_hidden=True, prompt=None)[1].float() + sequence_hidden = sequence_hidden.view(bs_pair, -1, + sequence_hidden.size(-1)) + + return sequence_hidden + + def get_visual_output(self, video, video_mask, shaped=False): + if shaped is False: + video_mask = video_mask.view(-1, video_mask.shape[-1]) + video = torch.as_tensor(video).float() + b, ts, channel, h, w = video.shape + video = video.view(b * ts, channel, h, w) + + bs_pair = video_mask.size(0) + visual_hidden = self.clip.encode_image(video).float() + visual_hidden = visual_hidden.float().view(bs_pair, -1, + visual_hidden.size(-1)) + + return visual_hidden + + def get_sequence_visual_output(self, + input_ids, + token_type_ids, + attention_mask, + video, + video_mask, + shaped=False): + if shaped is False: + input_ids = input_ids.view(-1, input_ids.shape[-1]) + token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) + attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) + video_mask = video_mask.view(-1, video_mask.shape[-1]) + + video = torch.as_tensor(video).float() + if len(video.shape) == 6: # image + b, bs, ts, channel, h, w = video.shape + b = b * bs + else: # video + b, ts, channel, h, w = video.shape + video = video.view(b * ts, channel, h, w) + + sequence_output = self.get_sequence_output( + input_ids, token_type_ids, attention_mask, shaped=True) + visual_output = self.get_visual_output(video, video_mask, shaped=True) + + return sequence_output, visual_output + + def agg_video_feat(self, visual_output, video_mask, sim_header='meanP'): + if self.config.max_sum == 0: + raise ValueError + + if sim_header == 'meanP': + # Default: Parameter-free type + pass + elif sim_header == 'seqLSTM': + # Sequential type: LSTM + visual_output_original = visual_output + visual_output = pack_padded_sequence( + visual_output, + torch.sum(video_mask, dim=-1).cpu(), + batch_first=True, + enforce_sorted=False) + visual_output, _ = self.lstm_visual(visual_output) + if self.training: + self.lstm_visual.flatten_parameters() + visual_output, _ = pad_packed_sequence( + visual_output, batch_first=True) + visual_output = torch.cat( + (visual_output, 
visual_output_original[:, + visual_output.size(1):, + ...].contiguous()), + dim=1) + visual_output = visual_output + visual_output_original + elif sim_header == 'seqTransf': + # Sequential type: Transformer Encoder + visual_output_original = visual_output + seq_length = visual_output.size(1) + position_ids = torch.arange( + seq_length, dtype=torch.long, device=visual_output.device) + position_ids = position_ids.unsqueeze(0).expand( + visual_output.size(0), -1) + frame_position_embeddings = self.frame_position_embeddings( + position_ids) + visual_output = visual_output + frame_position_embeddings + + extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0 + extended_video_mask = extended_video_mask.expand( + -1, video_mask.size(1), -1) + visual_output = visual_output.permute(1, 0, 2) # NLD -> LND + visual_output = self.transformerClip(visual_output, + extended_video_mask) + visual_output = visual_output.permute(1, 0, 2) # LND -> NLD + visual_output = visual_output + visual_output_original + + return visual_output + + def wti_interaction(self, text_feat, video_feat, text_mask, video_mask): + text_weight = self.text_weight_fc(text_feat).squeeze( + 2) # B x N_t x D -> B x N_t + text_weight.masked_fill_( + torch.tensor((1 - text_mask), dtype=torch.bool), float('-inf')) + text_weight = torch.softmax(text_weight, dim=-1) # B x N_t + + video_weight = self.video_weight_fc(video_feat).squeeze( + 2) # B x N_v x D -> B x N_v + video_weight.masked_fill_( + torch.tensor((1 - video_mask), dtype=torch.bool), float('-inf')) + video_weight = torch.softmax(video_weight, dim=-1) # B x N_v + + text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) + video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True) + + retrieve_logits = torch.einsum('atd,bvd->abtv', + [text_feat, video_feat]) + retrieve_logits = torch.einsum('abtv,at->abtv', + [retrieve_logits, text_mask]) + retrieve_logits = torch.einsum('abtv,bv->abtv', + [retrieve_logits, video_mask]) + + t2v_logits, max_idx1 = retrieve_logits.max(dim=-1) # abtv -> abt + t2v_logits = torch.einsum('abt,at->ab', [t2v_logits, text_weight]) + + v2t_logits, max_idx2 = retrieve_logits.max(dim=-2) # abtv -> abv + v2t_logits = torch.einsum('abv,bv->ab', [v2t_logits, video_weight]) + retrieve_logits = (t2v_logits + v2t_logits) / 2.0 + + if self.training: + logit_scale = self.clip.logit_scale.exp() + retrieve_logits = logit_scale * retrieve_logits + + # selecet max + max_idx1 = max_idx1[torch.arange(max_idx1.shape[0]), + torch.arange(max_idx1.shape[1])] + max_idx2 = max_idx2[torch.arange(max_idx2.shape[0]), + torch.arange(max_idx2.shape[1])] + + max_t_feat = text_feat[torch.arange(max_idx2.shape[0]). + repeat_interleave(max_idx2.shape[1]), + max_idx2.flatten()].squeeze(1) + max_v_feat = video_feat[torch.arange(max_idx1.shape[0]). 
+ repeat_interleave(max_idx1.shape[1]), + max_idx1.flatten()].squeeze(1) + + t_feat = text_feat.reshape(-1, text_feat.shape[-1]) + t_mask = text_mask.flatten().type(torch.bool) + v_feat = video_feat.reshape(-1, video_feat.shape[-1]) + v_mask = video_mask.flatten().type(torch.bool) + t_feat = t_feat[t_mask] + v_feat = v_feat[v_mask] + max_t_feat = max_t_feat[v_mask] + max_v_feat = max_v_feat[t_mask] + text_weight = text_weight.flatten()[t_mask] + video_weight = video_weight.flatten()[v_mask] + + z_a_norm = (t_feat - t_feat.mean(0)) / t_feat.std(0) # (BxN_t)xD + z_b_norm = (max_v_feat - max_v_feat.mean(0)) / max_v_feat.std( + 0) # (BxN_t)xD + + x_a_norm = (v_feat - v_feat.mean(0)) / v_feat.std(0) # (BxN_v)xD + x_b_norm = (max_t_feat - max_t_feat.mean(0)) / max_t_feat.std( + 0) # (BxN_v)xD + + # cross-correlation matrix + N, D = z_a_norm.shape + B = text_feat.shape[0] + c1 = torch.einsum('acd,a->cd', + torch.einsum('ac,ad->acd', z_a_norm, z_b_norm), + text_weight) / B # DxD + c2 = torch.einsum('acd,a->cd', + torch.einsum('ac,ad->acd', x_a_norm, x_b_norm), + video_weight) / B # DxD + c = (c1 + c2) / 2.0 + # loss + on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + off_diag = c.flatten()[1:].view(D - 1, D + 1)[:, :-1].pow_(2).sum() + cdcr_loss = ( + on_diag * self.config.cdcr_alpha1 + + off_diag * self.config.cdcr_alpha2) + return retrieve_logits, retrieve_logits.T, cdcr_loss + else: + return retrieve_logits, retrieve_logits.T + + def _loose_similarity(self, + sequence_output, + visual_output, + attention_mask, + video_mask, + sim_header='seqTransf'): + sequence_output, visual_output = sequence_output.contiguous( + ), visual_output.contiguous() + + visual_output = self.agg_video_feat(visual_output, video_mask, + sim_header) + + if self.training: # batch merge here + visual_output = allgather(visual_output, self.config) + attention_mask = allgather(attention_mask, self.config) + video_mask = allgather(video_mask, self.config) + sequence_output = allgather(sequence_output, self.config) + torch.distributed.barrier() # force sync + + return self.wti_interaction(sequence_output, visual_output, + attention_mask, video_mask) + + def get_similarity_logits(self, + sequence_output, + visual_output, + attention_mask, + video_mask, + shaped=False, + loose_type=False): + if shaped is False: + attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) + video_mask = video_mask.view(-1, video_mask.shape[-1]) + + if loose_type: + assert self.sim_header in ['meanP', 'seqLSTM', 'seqTransf'] + + if self.training: + retrieve_logits1, retrieve_logits2, barlow_loss = self._loose_similarity( + sequence_output, + visual_output, + attention_mask, + video_mask, + sim_header=self.sim_header) + return retrieve_logits1, retrieve_logits2, barlow_loss + else: + retrieve_logits1, retrieve_logits2 = self._loose_similarity( + sequence_output, + visual_output, + attention_mask, + video_mask, + sim_header=self.sim_header) + return retrieve_logits1, retrieve_logits2 + else: + raise NotImplementedError + + @property + def dtype(self): + """ + :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 
+ """ + try: + return next(self.parameters()).dtype + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + def find_tensor_attributes(module: nn.Module): + tuples = [(k, v) for k, v in module.__dict__.items() + if torch.is_tensor(v)] + return tuples + + gen = self._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + def init_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, LayerNorm): + if 'beta' in dir(module) and 'gamma' in dir(module): + module.beta.data.zero_() + module.gamma.data.fill_(1.0) + else: + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() diff --git a/modelscope/models/multi_modal/mmr/models/module_clip.py b/modelscope/models/multi_modal/mmr/models/module_clip.py new file mode 100644 index 00000000..36e56196 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/module_clip.py @@ -0,0 +1,526 @@ +# Part of the implementation is borrowed and modified from The OpenAI CLIP project. + +import hashlib +import os +import urllib +import warnings +from collections import OrderedDict +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch import nn +from tqdm import tqdm + +_MODELS = {} +_PT_NAME = {'ViT-B/16': 'ViT-B-16.pt'} + + +def available_models(): + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super(Bottleneck, self).__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super(AttentionPool2d, self).__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, + 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super(ModifiedResNet, self).__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, d_model: int, n_head: int, attn_mask=None): + super(ResidualAttentionBlock, self).__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + attn_mask_ = self.attn_mask + if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): + attn_mask_ = self.attn_mask(x.size(0)) # LND + + attn_mask_ = attn_mask_.to( + dtype=x.dtype, device=x.device) if attn_mask_ is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] + + def forward(self, x): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask=None, + use_gc=0): + super(Transformer, self).__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + 
ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + self.use_gc = use_gc + + def forward(self, x: torch.Tensor): + if self.use_gc > 0: + for blk in self.resblocks: + x = checkpoint.checkpoint(blk, x) + return x + else: + return self.resblocks(x) + + +class VisualTransformer(nn.Module): + + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + linear_patch: str = '2d', + use_gc: int = 0): + super(VisualTransformer, self).__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads, use_gc=use_gc) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + # For 3D + assert linear_patch in ['2d', '3d'] + self.linear_patch = linear_patch + if self.linear_patch == '3d': + self.conv2 = nn.Conv3d( + in_channels=3, + out_channels=width, + kernel_size=(3, patch_size, patch_size), + stride=(1, patch_size, patch_size), + padding=(1, 0, 0), + bias=False) + + def forward(self, x: torch.Tensor, video_frame=-1): + + if self.linear_patch == '3d': + assert video_frame != -1 + x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], + x.shape[-1]) + x_3d = x_3d.permute(0, 2, 1, 3, 4) + x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid] + x_3d = x_3d.permute(0, 2, 1, 3, + 4) # shape = [*, frame, width, grid, grid] + x = x_3d.reshape( + -1, x_3d.shape[-3], x_3d.shape[-2], + x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid] + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + _x = self.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([_x, x], dim=1) + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + return x + + +class CLIP(nn.Module): + + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + # vision linear of patch + linear_patch: str = '2d', + use_gc: int = 0): + super(CLIP, self).__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + linear_patch=linear_patch, + use_gc=use_gc) + + self.transformer = Transformer( + 
width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([])) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [ + self.visual.layer1, self.visual.layer2, self.visual.layer3, + self.visual.layer4 + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith('bn3.weight'): + nn.init.zeros_(param) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self, context_length): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.zeros(context_length, context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image, return_hidden=False): + hidden = self.visual(image.type(self.dtype)) + hidden = self.visual.ln_post(hidden) @ self.visual.proj + + x = hidden[:, 0, :] + + if return_hidden: + return x, hidden + + return x + + def encode_text(self, text, return_hidden=False, prompt=None): + x = self.token_embedding(text).type( + self.dtype) # [batch_size, n_ctx, d_model] + if prompt: + x = prompt(x) + + pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype) + x = x + pos_emd + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + hidden = self.ln_final(x).type(self.dtype) @ self.text_projection + + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)] + + if return_hidden: + return x, hidden + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() 
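+        # logit_scale is a learnable scalar; exponentiating it keeps the
+        # temperature positive before it scales the similarity logits below.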
+ logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(lay): + # l = lay + if isinstance(lay, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): + lay.weight.data = lay.weight.data.half() + if lay.bias is not None: + lay.bias.data = lay.bias.data.half() + + if isinstance(lay, nn.MultiheadAttention): + for attr in [ + *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], + 'in_proj_bias', 'bias_k', 'bias_v' + ]: + tensor = getattr(lay, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ['text_projection', 'proj']: + if hasattr(lay, name): + attr = getattr(lay, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) diff --git a/modelscope/models/multi_modal/mmr/models/module_cross.py b/modelscope/models/multi_modal/mmr/models/module_cross.py new file mode 100644 index 00000000..05edb853 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/module_cross.py @@ -0,0 +1,100 @@ +from __future__ import absolute_import, division, print_function +import logging +from collections import OrderedDict + +import json +import torch +from torch import nn + +from .until_module import ACT2FN, LayerNorm + +logger = logging.getLogger(__name__) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, d_model: int, n_head: int): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.n_head = n_head + + def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): + attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0) + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] + + def forward(self, para_tuple: tuple): + # x: torch.Tensor, attn_mask: torch.Tensor + x, attn_mask = para_tuple + x = x + self.attention(self.ln_1(x), attn_mask) + x = x + self.mlp(self.ln_2(x)) + return (x, attn_mask) + + +class Transformer(nn.Module): + + def __init__(self, width: int, layers: int, heads: int): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential( + *[ResidualAttentionBlock(width, heads) for _ in range(layers)]) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + return self.resblocks((x, attn_mask))[0] + + +class CrossEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. 
+ """ + + def __init__(self, config): + super(CrossEmbeddings, self).__init__() + + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, concat_embeddings, concat_type=None): + + _, seq_length = concat_embeddings.size(0), concat_embeddings.size(1) + position_ids = torch.arange( + seq_length, dtype=torch.long, device=concat_embeddings.device) + position_ids = position_ids.unsqueeze(0).expand( + concat_embeddings.size(0), -1) + + position_embeddings = self.position_embeddings(position_ids) + + embeddings = concat_embeddings + position_embeddings # + token_type_embeddings + embeddings = self.dropout(embeddings) + return embeddings + + +class CrossPooler(nn.Module): + + def __init__(self, config): + super(CrossPooler, self).__init__() + self.ln_pool = LayerNorm(config.hidden_size) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = QuickGELU() + + def forward(self, hidden_states, hidden_mask): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + hidden_states = self.ln_pool(hidden_states) + pooled_output = hidden_states[:, 0] + pooled_output = self.dense(pooled_output) + pooled_output = self.activation(pooled_output) + return pooled_output diff --git a/modelscope/models/multi_modal/mmr/models/tokenization_clip.py b/modelscope/models/multi_modal/mmr/models/tokenization_clip.py new file mode 100644 index 00000000..ee60f857 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/tokenization_clip.py @@ -0,0 +1,158 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + + def __init__(self, model_dir): + bpe_path = '{}/bpe_simple_vocab_16e6.txt.gz'.format(model_dir) + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + self.vocab = self.encoder + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text + + def tokenize(self, text): + tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + tokens.extend( + bpe_token for bpe_token in self.bpe(token).split(' ')) + return tokens + + def convert_tokens_to_ids(self, tokens): + return [self.encoder[bpe_token] for bpe_token in tokens] diff --git a/modelscope/models/multi_modal/mmr/models/until_module.py b/modelscope/models/multi_modal/mmr/models/until_module.py new file mode 100644 index 00000000..24e886b0 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/until_module.py @@ -0,0 +1,120 @@ +# Copyright 2018 The Google AI Language Team Authors and The HugginFace 
Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +import logging +import math + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +logger = logging.getLogger(__name__) + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish} + + +class LayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(LayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class CrossEn(nn.Module): + + def __init__(self, config=None): + super(CrossEn, self).__init__() + + def forward(self, sim_matrix): + logpt = F.log_softmax(sim_matrix, dim=-1) + logpt = torch.diag(logpt) + nce_loss = -logpt + sim_loss = nce_loss.mean() + return sim_loss + + +class AllGather(torch.autograd.Function): + """An autograd function that performs allgather on a tensor.""" + + @staticmethod + def forward(ctx, tensor, args): + if args.world_size == 1: + ctx.rank = args.local_rank + ctx.batch_size = tensor.shape[0] + return tensor + else: + output = [torch.empty_like(tensor) for _ in range(args.world_size)] + torch.distributed.all_gather(output, tensor) + ctx.rank = args.local_rank + ctx.batch_size = tensor.shape[0] + return torch.cat(output, dim=0) + + @staticmethod + def backward(ctx, grad_output): + return ( + grad_output[ctx.batch_size * ctx.rank:ctx.batch_size + * (ctx.rank + 1)], + None, + ) + + +class AllGather2(torch.autograd.Function): + """An autograd function that performs allgather on a tensor.""" + # https://github.com/PyTorchLightning/lightning-bolts/blob/8d3fbf7782e3d3937ab8a1775a7092d7567f2933/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20 + @staticmethod + def forward(ctx, tensor, args): + if args.world_size == 1: + ctx.rank = args.local_rank + ctx.batch_size = tensor.shape[0] + return tensor + else: + output = [torch.empty_like(tensor) for _ in range(args.world_size)] + torch.distributed.all_gather(output, tensor) + ctx.rank = args.local_rank + ctx.batch_size = tensor.shape[0] + return torch.cat(output, dim=0) + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + torch.distributed.all_reduce( + grad_input, op=torch.distributed.ReduceOp.SUM, 
async_op=False) + return (grad_input[ctx.rank * ctx.batch_size:(ctx.rank + 1) + * ctx.batch_size], None) diff --git a/modelscope/models/nlp/bert_for_sequence_classification.py b/modelscope/models/nlp/bert_for_sequence_classification.py index 1a843129..530ba786 100644 --- a/modelscope/models/nlp/bert_for_sequence_classification.py +++ b/modelscope/models/nlp/bert_for_sequence_classification.py @@ -25,9 +25,9 @@ class BertForSequenceClassification(Model): """ super().__init__(model_dir, *args, **kwargs) + import torch from easynlp.appzoo import SequenceClassification from easynlp.core.predictor import get_model_predictor - import torch self.model = get_model_predictor( model_dir=self.model_dir, model_cls=SequenceClassification, diff --git a/modelscope/models/nlp/palm_for_text_generation.py b/modelscope/models/nlp/palm_for_text_generation.py index 9e387e87..245a5fdb 100644 --- a/modelscope/models/nlp/palm_for_text_generation.py +++ b/modelscope/models/nlp/palm_for_text_generation.py @@ -21,7 +21,8 @@ class PalmForTextGeneration(TorchModel): """ super().__init__(model_dir, *args, **kwargs) - from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator + from sofa.models.palm_v2 import (PalmForConditionalGeneration, + Translator) self.model = PalmForConditionalGeneration.from_pretrained(model_dir) self.tokenizer = self.model.tokenizer self.generator = Translator(self.model) diff --git a/modelscope/models/nlp/space_for_dialog_intent_prediction.py b/modelscope/models/nlp/space_for_dialog_intent_prediction.py index 247d0cc7..da11c52f 100644 --- a/modelscope/models/nlp/space_for_dialog_intent_prediction.py +++ b/modelscope/models/nlp/space_for_dialog_intent_prediction.py @@ -27,7 +27,8 @@ class SpaceForDialogIntent(Model): """ super().__init__(model_dir, *args, **kwargs) - from modelscope.trainers.nlp.space.trainer.intent_trainer import IntentTrainer + from modelscope.trainers.nlp.space.trainer.intent_trainer import \ + IntentTrainer self.model_dir = model_dir self.config = kwargs.pop( 'config', diff --git a/modelscope/models/nlp/space_for_dialog_state_tracking.py b/modelscope/models/nlp/space_for_dialog_state_tracking.py index 2587d2fd..7cfb1c54 100644 --- a/modelscope/models/nlp/space_for_dialog_state_tracking.py +++ b/modelscope/models/nlp/space_for_dialog_state_tracking.py @@ -22,7 +22,7 @@ class SpaceForDialogStateTracking(Model): super().__init__(model_dir, *args, **kwargs) - from sofa.models.space import SpaceForDST, SpaceConfig + from sofa.models.space import SpaceConfig, SpaceForDST self.model_dir = model_dir self.config = SpaceConfig.from_pretrained(self.model_dir) diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index 80fdb8d8..efe624cb 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -225,8 +225,8 @@ class MsDataset: continue retained_columns.append(k) - import torch import math + import torch class MsIterableDataset(torch.utils.data.IterableDataset): diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 224d6379..8d9ed1da 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -74,6 +74,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.text_to_image_synthesis: (Pipelines.text_to_image_synthesis, 'damo/cv_imagen_text-to-image-synthesis_tiny'), + Tasks.video_multi_modal_embedding: + (Pipelines.video_multi_modal_embedding, + 'damo/multi_modal_clip_vtretrival_msrvtt_53'), Tasks.image_color_enhance: (Pipelines.image_color_enhance, 
'damo/cv_csrnet_image-color-enhance-models'), Tasks.virtual_tryon: (Pipelines.virtual_tryon, diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index 0f3c0444..26186112 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -3,7 +3,10 @@ try: from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline - from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline + from .video_multi_modal_embedding_pipeline import \ + VideoMultiModalEmbeddingPipeline + from .visual_question_answering_pipeline import \ + VisualQuestionAnsweringPipeline except ModuleNotFoundError as e: if str(e) == "No module named 'torch'": pass diff --git a/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py new file mode 100644 index 00000000..627c5ce6 --- /dev/null +++ b/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py @@ -0,0 +1,42 @@ +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from ..base import Model, Pipeline +from ..builder import PIPELINES + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_multi_modal_embedding, + module_name=Pipelines.video_multi_modal_embedding) +class VideoMultiModalEmbeddingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a video_multi_modal_embedding pipeline for prediction + Args: + model: model id on modelscope hub. 
+ """ + super().__init__(model=model) + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: + with self.place_device(): + out = self.forward(input) + + self._check_output(out) + return out + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 59e93dee..c97a9a10 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -15,6 +15,7 @@ try: from .dialog_modeling_pipeline import * # noqa F403 from .dialog_state_tracking_pipeline import * # noqa F403 from .fill_mask_pipeline import * # noqa F403 + from .named_entity_recognition_pipeline import * # noqa F403 from .nli_pipeline import * # noqa F403 from .sentence_similarity_pipeline import * # noqa F403 from .sentiment_classification_pipeline import * # noqa F403 @@ -22,7 +23,6 @@ try: from .text_generation_pipeline import * # noqa F403 from .word_segmentation_pipeline import * # noqa F403 from .zero_shot_classification_pipeline import * # noqa F403 - from .named_entity_recognition_pipeline import * # noqa F403 except ModuleNotFoundError as e: if str(e) == "No module named 'torch'": pass diff --git a/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py b/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py index 6ddb9a9c..ca629222 100644 --- a/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py @@ -25,7 +25,7 @@ class DialogStateTrackingPreprocessor(Preprocessor): """ super().__init__(*args, **kwargs) - from sofa.models.space import SpaceTokenizer, SpaceConfig + from sofa.models.space import SpaceConfig, SpaceTokenizer self.model_dir: str = model_dir self.config = SpaceConfig.from_pretrained(self.model_dir) self.tokenizer = SpaceTokenizer.from_pretrained(self.model_dir) diff --git a/modelscope/trainers/nlp/sequence_classification_trainer.py b/modelscope/trainers/nlp/sequence_classification_trainer.py index 883110db..23d3a3f5 100644 --- a/modelscope/trainers/nlp/sequence_classification_trainer.py +++ b/modelscope/trainers/nlp/sequence_classification_trainer.py @@ -78,7 +78,8 @@ class SequenceClassificationTrainer(BaseTrainer): import torch from easynlp.appzoo import load_dataset from easynlp.appzoo.dataset import GeneralDataset - from easynlp.appzoo.sequence_classification.model import SequenceClassification + from easynlp.appzoo.sequence_classification.model import \ + SequenceClassification from easynlp.utils import losses from sklearn.metrics import f1_score from torch.utils.data import DataLoader diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 44cd87f4..220606b7 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -77,6 +77,7 @@ class MultiModalTasks(object): multi_modal_embedding = 'multi-modal-embedding' generative_multi_modal_embedding = 'generative-multi-modal-embedding' visual_question_answering = 'visual-question-answering' + video_multi_modal_embedding = 'video-multi-modal-embedding' class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): @@ -85,7 +86,6 @@ class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): Holds the standard task name to use for identifying different tasks. 
This should be used to register models, pipelines, trainers. """ - reverse_field_index = {} @staticmethod diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index dd00fd5b..5af67944 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -89,8 +89,8 @@ def get_model_type(model_dir): def parse_label_mapping(model_dir): - import os import json + import os label2id = None label_path = os.path.join(model_dir, ModelFile.LABEL_MAPPING) if os.path.exists(label_path): diff --git a/setup.py b/setup.py index 97778b89..8111829f 100644 --- a/setup.py +++ b/setup.py @@ -70,9 +70,9 @@ def parse_requirements(fname='requirements.txt', with_version=True): CommandLine: python -c "import setup; print(setup.parse_requirements())" """ + import re import sys from os.path import exists - import re require_fpath = fname def parse_line(line): diff --git a/tests/pipelines/test_video_multi_modal_embedding.py b/tests/pipelines/test_video_multi_modal_embedding.py new file mode 100644 index 00000000..943dbed9 --- /dev/null +++ b/tests/pipelines/test_video_multi_modal_embedding.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np + +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class VideoMultiModalEmbeddingTest(unittest.TestCase): + + model_id = 'damo/multi_modal_clip_vtretrival_msrvtt_53' + video_path = 'data/test/videos/multi_modal_test_video_9770.mp4' + caption = ('a person is connecting something to system', None, None) + _input = {'video': video_path, 'text': caption} + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + pipeline_video_multi_modal_embedding = pipeline( + Tasks.video_multi_modal_embedding, model=self.model_id) + output = pipeline_video_multi_modal_embedding(self._input) + logger.info('text feature: {}'.format( + output['text_embedding'][0][0][0])) + logger.info('video feature: {}'.format( + output['video_embedding'][0][0][0])) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_video_multi_modal_embedding = pipeline( + task=Tasks.video_multi_modal_embedding) + output = pipeline_video_multi_modal_embedding(self._input) + logger.info('text feature: {}'.format( + output['text_embedding'][0][0][0])) + logger.info('video feature: {}'.format( + output['video_embedding'][0][0][0])) + + +if __name__ == '__main__': + unittest.main()
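A minimal round-trip sketch for the SimpleTokenizer added in this patch. The module path
(modelscope.models.multi_modal.mmr.models.tokenization_clip) and the local model directory
holding bpe_simple_vocab_16e6.txt.gz are assumptions for illustration, not verified here:

    from modelscope.models.multi_modal.mmr.models.tokenization_clip import \
        SimpleTokenizer

    # '/path/to/model_dir' is a hypothetical directory that contains
    # bpe_simple_vocab_16e6.txt.gz (shipped alongside the model weights)
    tokenizer = SimpleTokenizer(model_dir='/path/to/model_dir')
    tokens = tokenizer.tokenize('a person is connecting something to system')
    ids = tokenizer.convert_tokens_to_ids(tokens)
    # decode() maps ids back to text, turning the end-of-word marker into a space
    print(tokens, ids, tokenizer.decode(ids).strip())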
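The CrossEn module in until_module.py is a contrastive (InfoNCE-style) retrieval loss: it applies
a row-wise log-softmax to a text-video similarity matrix and averages the negative log-probabilities
of the diagonal (matched) pairs. A small sketch, assuming L2-normalized batch features; the
temperature scaling used by the full model is omitted:

    import torch
    import torch.nn.functional as F

    from modelscope.models.multi_modal.mmr.models.until_module import CrossEn

    text_feat = F.normalize(torch.randn(4, 512), dim=-1)   # 4 captions
    video_feat = F.normalize(torch.randn(4, 512), dim=-1)  # 4 candidate videos
    sim_matrix = text_feat @ video_feat.t()  # (4, 4); matched pairs sit on the diagonal
    loss_t2v = CrossEn()(sim_matrix)         # text-to-video direction
    loss_v2t = CrossEn()(sim_matrix.t())     # video-to-text direction
    loss = (loss_t2v + loss_v2t) / 2
    print(float(loss))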
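And a usage sketch for the new pipeline itself, mirroring the test above. The task name, model id,
input dict and the output keys 'text_embedding' / 'video_embedding' come from the test; the exact
tensor shapes are an assumption, so both features are mean-pooled to 1-D (and assumed to share the
same embedding dimension) before the illustrative cosine-similarity step:

    import numpy as np

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks


    def to_vector(feature):
        # accept a torch tensor or a (nested) array and reduce it to a flat 1-D numpy vector
        arr = feature.detach().cpu().numpy() if hasattr(feature, 'detach') \
            else np.asarray(feature)
        while arr.ndim > 1:  # average away batch/frame axes, if any
            arr = arr.mean(axis=0)
        return arr


    pipe = pipeline(
        Tasks.video_multi_modal_embedding,
        model='damo/multi_modal_clip_vtretrival_msrvtt_53')
    output = pipe({
        'video': 'data/test/videos/multi_modal_test_video_9770.mp4',
        'text': ('a person is connecting something to system', None, None)
    })
    text_vec = to_vector(output['text_embedding'])
    video_vec = to_vector(output['video_embedding'])
    # cosine similarity: higher means a better text-video match
    score = float(
        np.dot(text_vec, video_vec)
        / (np.linalg.norm(text_vec) * np.linalg.norm(video_vec) + 1e-8))
    print('text-video cosine similarity: {:.4f}'.format(score))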