From 82e222ed14bc6703cf378ce024e41edec77ddce4 Mon Sep 17 00:00:00 2001 From: "yongfei.zyf" Date: Mon, 4 Jul 2022 14:03:24 +0800 Subject: [PATCH 1/3] [to #42322933] Add cv-action-recognition-pipeline: run inference on the CPU MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Action recognition inference now supports both CPU and GPU Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9253438 --- modelscope/pipelines/cv/action_recognition_pipeline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py index fce037d8..8eefa301 100644 --- a/modelscope/pipelines/cv/action_recognition_pipeline.py +++ b/modelscope/pipelines/cv/action_recognition_pipeline.py @@ -32,7 +32,9 @@ class ActionRecognitionPipeline(Pipeline): config_path = osp.join(self.model, ModelFile.CONFIGURATION) logger.info(f'loading config from {config_path}') self.cfg = Config.from_file(config_path) - self.infer_model = BaseVideoModel(cfg=self.cfg).cuda() + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device) self.infer_model.eval() self.infer_model.load_state_dict(torch.load(model_path)['model_state']) self.label_mapping = self.cfg.label_mapping @@ -40,7 +42,7 @@ class ActionRecognitionPipeline(Pipeline): def preprocess(self, input: Input) -> Dict[str, Any]: if isinstance(input, str): - video_input_data = ReadVideoData(self.cfg, input).cuda() + video_input_data = ReadVideoData(self.cfg, input).to(self.device) else: raise TypeError(f'input should be a str,' f' but got {type(input)}') From d78d944246431b4f60fdc81235e9c31246ab4f9a Mon Sep 17 00:00:00 2001 From: "yingda.chen" Date: Mon, 4 Jul 2022 14:21:22 +0800 Subject: [PATCH 2/3] [to #42322933] Support text-to-image synthesis default model Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9253639 --- modelscope/pipelines/builder.py | 3 +++ modelscope/pipelines/cv/action_recognition_pipeline.py | 3 --- .../multi_modal/text_to_image_synthesis_pipeline.py | 4 ++-- tests/pipelines/test_action_recognition.py | 6 +----- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 346d8048..96043ce9 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -57,6 +57,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.visual_question_answering: (Pipelines.visual_question_answering, 'damo/mplug_visual-question-answering_coco_large_en'), + Tasks.text_to_image_synthesis: + (Pipelines.text_to_image_synthesis, + 'damo/cv_imagen_text-to-image-synthesis_tiny') } diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py index 8eefa301..757f87e3 100644 --- a/modelscope/pipelines/cv/action_recognition_pipeline.py +++ b/modelscope/pipelines/cv/action_recognition_pipeline.py @@ -2,9 +2,6 @@ import math import os.path as osp from typing import Any, Dict -import cv2 -import numpy as np -import PIL import torch from modelscope.metainfo import Pipelines diff --git a/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py index edffe1f2..02a34428 100644 --- a/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py +++ b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py @@ -1,4 +1,4 @@ -from 
typing import Any, Dict, Union +from typing import Any, Dict from modelscope.metainfo import Pipelines from modelscope.pipelines.base import Input @@ -23,7 +23,7 @@ class TextToImageSynthesisPipeline(Pipeline): pipe_model = model else: raise NotImplementedError( - f'execpting a Model instance or str, but get {type(model)}.') + f'expecting a Model instance or str, but get {type(model)}.') super().__init__(model=pipe_model) diff --git a/tests/pipelines/test_action_recognition.py b/tests/pipelines/test_action_recognition.py index 6f608041..7453f136 100644 --- a/tests/pipelines/test_action_recognition.py +++ b/tests/pipelines/test_action_recognition.py @@ -1,14 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # !/usr/bin/env python import os.path as osp -import shutil import tempfile import unittest -import cv2 - from modelscope.fileio import File -from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.test_utils import test_level @@ -45,7 +41,7 @@ class ActionRecognitionTest(unittest.TestCase): print(f'recognition output: {result}.') - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_modelhub_default_model(self): recognition_pipeline = pipeline(Tasks.action_recognition) result = recognition_pipeline( From d74e644aaf50322c73a1889585b6430e436036c4 Mon Sep 17 00:00:00 2001 From: ly103369 Date: Mon, 4 Jul 2022 16:05:52 +0800 Subject: [PATCH 3/3] [to #42322933] Add cv_r2p1d_video_embedding to maas lib Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9246462 --- modelscope/metainfo.py | 1 + .../cv/cmdssl_video_embedding/__init__.py | 3 + .../models/cv/cmdssl_video_embedding/c3d.py | 121 +++++++ .../cv/cmdssl_video_embedding/resnet2p1d.py | 339 ++++++++++++++++++ .../cv/cmdssl_video_embedding/resnet3d.py | 284 +++++++++++++++ modelscope/pipelines/builder.py | 2 + modelscope/pipelines/cv/__init__.py | 1 + .../cv/cmdssl_video_embedding_pipleline.py | 157 ++++++++ modelscope/pipelines/outputs.py | 7 + modelscope/utils/constant.py | 1 + .../pipelines/test_cmdssl_video_embedding.py | 30 ++ 11 files changed, 946 insertions(+) create mode 100644 modelscope/models/cv/cmdssl_video_embedding/__init__.py create mode 100644 modelscope/models/cv/cmdssl_video_embedding/c3d.py create mode 100644 modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py create mode 100644 modelscope/models/cv/cmdssl_video_embedding/resnet3d.py create mode 100644 modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py create mode 100644 tests/pipelines/test_cmdssl_video_embedding.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 520726a2..21e13252 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -49,6 +49,7 @@ class Pipelines(object): ocr_detection = 'resnet18-ocr-detection' action_recognition = 'TAdaConv_action-recognition' animal_recognation = 'resnet101-animal_recog' + cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/cmdssl_video_embedding/__init__.py b/modelscope/models/cv/cmdssl_video_embedding/__init__.py new file mode 100644 index 00000000..06669d2b --- /dev/null +++ b/modelscope/models/cv/cmdssl_video_embedding/__init__.py @@ -0,0 +1,3 @@ +from .c3d import C3D +from .resnet2p1d import resnet26_2p1d +from .resnet3d import resnet26_3d diff --git 
a/modelscope/models/cv/cmdssl_video_embedding/c3d.py b/modelscope/models/cv/cmdssl_video_embedding/c3d.py new file mode 100644 index 00000000..62f0e0b9 --- /dev/null +++ b/modelscope/models/cv/cmdssl_video_embedding/c3d.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn + + +def conv3x3x3(in_planes, out_planes, stride=1): + return nn.Conv3d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1) + + +class C3D(nn.Module): + + def __init__(self, + num_classes=1000, + dropout=0.5, + inplanes=3, + norm_layer=None, + last_pool=True): + super(C3D, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if not last_pool and num_classes is not None: + raise ValueError('num_classes should be None when last_pool=False') + + self.conv1 = conv3x3x3(inplanes, 64) + self.bn1 = norm_layer(64) + self.relu1 = nn.ReLU(inplace=True) + self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) + + self.conv2 = conv3x3x3(64, 128) + self.bn2 = norm_layer(128) + self.relu2 = nn.ReLU(inplace=True) + self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) + + self.conv3a = conv3x3x3(128, 256) + self.bn3a = norm_layer(256) + self.relu3a = nn.ReLU(inplace=True) + + self.conv3b = conv3x3x3(256, 256) + self.bn3b = norm_layer(256) + self.relu3b = nn.ReLU(inplace=True) + self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) + + self.conv4a = conv3x3x3(256, 512) + self.bn4a = norm_layer(512) + self.relu4a = nn.ReLU(inplace=True) + + self.conv4b = conv3x3x3(512, 512) + self.bn4b = norm_layer(512) + self.relu4b = nn.ReLU(inplace=True) + self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) + + self.conv5a = conv3x3x3(512, 512) + self.bn5a = norm_layer(512) + self.relu5a = nn.ReLU(inplace=True) + + self.conv5b = conv3x3x3(512, 512) + self.bn5b = norm_layer(512) + self.relu5b = nn.ReLU(inplace=True) + self.pool5 = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None + + if num_classes is None: + self.dropout = None + self.fc = None + else: + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(512, num_classes) + self.out_planes = 512 + + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.pool1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + x = self.pool2(x) + + x = self.conv3a(x) + x = self.bn3a(x) + x = self.relu3a(x) + + x = self.conv3b(x) + x = self.bn3b(x) + x = self.relu3b(x) + x = self.pool3(x) + + x = self.conv4a(x) + x = self.bn4a(x) + x = self.relu4a(x) + + x = self.conv4b(x) + x = self.bn4b(x) + x = self.relu4b(x) + x = self.pool4(x) + + x = self.conv5a(x) + x = self.bn5a(x) + x = self.relu5a(x) + + x = self.conv5b(x) + x = self.bn5b(x) + x = self.relu5b(x) + + if self.pool5: + x = self.pool5(x) + x = torch.flatten(x, 1) + if self.dropout and self.fc: + x = self.dropout(x) + x = self.fc(x) + + return x diff --git a/modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py b/modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py new file mode 100644 index 00000000..3b03cc74 --- /dev/null +++ b/modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py @@ -0,0 +1,339 @@ +import torch +import torch.nn as nn + + +def conv1x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + return nn.Conv3d( + in_planes, + 
out_planes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, dilation, dilation), + groups=groups, + bias=False, + dilation=(1, dilation, dilation)) + + +def conv3x1x1(in_planes, out_planes, stride=1, groups=1, dilation=1): + return nn.Conv3d( + in_planes, + out_planes, + kernel_size=(3, 1, 1), + stride=(stride, 1, 1), + padding=(dilation, 0, 0), + groups=groups, + bias=False, + dilation=(dilation, 1, 1)) + + +def conv1x1x1(in_planes, out_planes, stride=1): + return nn.Conv3d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + 'Dilation > 1 not supported in BasicBlock') + + midplanes1 = (inplanes * planes * 3 * 3 * 3) // ( + inplanes * 3 * 3 + planes * 3) + self.conv1_s = conv1x3x3(inplanes, midplanes1, stride) + self.bn1_s = norm_layer(midplanes1) + self.conv1_t = conv3x1x1(midplanes1, planes, stride) + self.bn1_t = norm_layer(planes) + + midplanes2 = (planes * planes * 3 * 3 * 3) // ( + planes * 3 * 3 + planes * 3) + self.conv2_s = conv1x3x3(planes, midplanes2) + self.bn2_s = norm_layer(midplanes2) + self.conv2_t = conv3x1x1(midplanes2, planes) + self.bn2_t = norm_layer(planes) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1_s(x) + out = self.bn1_s(out) + out = self.relu(out) + out = self.conv1_t(out) + out = self.bn1_t(out) + out = self.relu(out) + + out = self.conv2_s(out) + out = self.bn2_s(out) + out = self.relu(out) + out = self.conv2_t(out) + out = self.bn2_t(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + width = int(planes * (base_width / 64.)) * groups + + self.conv1 = conv1x1x1(inplanes, width) + self.bn1 = norm_layer(width) + + midplanes = (width * width * 3 * 3 * 3) // (width * 3 * 3 + width * 3) + self.conv2_s = conv1x3x3(width, midplanes, stride, groups, dilation) + self.bn2_s = norm_layer(midplanes) + self.conv2_t = conv3x1x1(midplanes, width, stride, groups, dilation) + self.bn2_t = norm_layer(width) + + self.conv3 = conv1x1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2_s(out) + out = self.bn2_s(out) + out = self.relu(out) + out = self.conv2_t(out) + out = self.bn2_t(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet2p1d(nn.Module): + + def __init__(self, + block, + layers, + num_classes=None, + 
zero_init_residual=True, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + dropout=0.5, + inplanes=3, + first_stride=2, + norm_layer=None, + last_pool=True): + super(ResNet2p1d, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if not last_pool and num_classes is not None: + raise ValueError('num_classes should be None when last_pool=False') + self._norm_layer = norm_layer + self.first_stride = first_stride + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError('replace_stride_with_dilation should be None ' + 'or a 3-element tuple, got {}'.format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + + midplanes = (3 * self.inplanes * 3 * 7 * 7) // (3 * 7 * 7 + + self.inplanes * 3) + self.conv1_s = nn.Conv3d( + inplanes, + midplanes, + kernel_size=(1, 7, 7), + stride=(1, first_stride, first_stride), + padding=(0, 3, 3), + bias=False) + self.bn1_s = norm_layer(midplanes) + self.conv1_t = nn.Conv3d( + midplanes, + self.inplanes, + kernel_size=(3, 1, 1), + stride=(1, 1, 1), + padding=(1, 0, 0), + bias=False) + self.bn1_t = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool3d( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None + if num_classes is None: + self.dropout = None + self.fc = None + else: + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(512 * block.expansion, num_classes) + self.out_planes = 512 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2_t.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion)) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1_s(x) + x = self.bn1_s(x) + x = self.relu(x) + x = self.conv1_t(x) + x = self.bn1_t(x) + x = self.relu(x) + x = 
self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if self.avgpool: + x = self.avgpool(x) + x = torch.flatten(x, 1) + if self.dropout and self.fc: + x = self.dropout(x) + x = self.fc(x) + + return x + + +def resnet10_2p1d(**kwargs): + return ResNet2p1d(BasicBlock, [1, 1, 1, 1], **kwargs) + + +def resnet18_2p1d(**kwargs): + return ResNet2p1d(BasicBlock, [2, 2, 2, 2], **kwargs) + + +def resnet26_2p1d(**kwargs): + return ResNet2p1d(Bottleneck, [2, 2, 2, 2], **kwargs) + + +def resnet34_2p1d(**kwargs): + return ResNet2p1d(BasicBlock, [3, 4, 6, 3], **kwargs) + + +def resnet50_2p1d(**kwargs): + return ResNet2p1d(Bottleneck, [3, 4, 6, 3], **kwargs) + + +def resnet101_2p1d(**kwargs): + return ResNet2p1d(Bottleneck, [3, 4, 23, 3], **kwargs) + + +def resnet152_2p1d(**kwargs): + return ResNet2p1d(Bottleneck, [3, 8, 36, 3], **kwargs) + + +def resnet200_2p1d(**kwargs): + return ResNet2p1d(Bottleneck, [3, 24, 36, 3], **kwargs) diff --git a/modelscope/models/cv/cmdssl_video_embedding/resnet3d.py b/modelscope/models/cv/cmdssl_video_embedding/resnet3d.py new file mode 100644 index 00000000..24d50a8e --- /dev/null +++ b/modelscope/models/cv/cmdssl_video_embedding/resnet3d.py @@ -0,0 +1,284 @@ +import torch +import torch.nn as nn + + +def conv3x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + return nn.Conv3d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1x1(in_planes, out_planes, stride=1): + return nn.Conv3d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + 'Dilation > 1 not supported in BasicBlock') + self.conv1 = conv3x3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + width = int(planes * (base_width / 64.)) * groups + self.conv1 = conv1x1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + 
+ out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet3d(nn.Module): + + def __init__(self, + block, + layers, + num_classes=1000, + zero_init_residual=True, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + dropout=0.5, + inplanes=3, + first_stride=2, + norm_layer=None, + last_pool=True): + super(ResNet3d, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if not last_pool and num_classes is not None: + raise ValueError('num_classes should be None when last_pool=False') + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError('replace_stride_with_dilation should be None ' + 'or a 3-element tuple, got {}'.format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv3d( + inplanes, + self.inplanes, + kernel_size=(3, 7, 7), + stride=(1, first_stride, first_stride), + padding=(1, 3, 3), + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool3d( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None + if num_classes is None: + self.dropout = None + self.fc = None + else: + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(512 * block.expansion, num_classes) + self.out_planes = 512 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion)) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if 
self.avgpool: + x = self.avgpool(x) + x = torch.flatten(x, 1) + if self.dropout and self.fc: + x = self.dropout(x) + x = self.fc(x) + + return x + + +def resnet10_3d(**kwargs): + return ResNet3d(BasicBlock, [1, 1, 1, 1], **kwargs) + + +def resnet18_3d(**kwargs): + return ResNet3d(BasicBlock, [2, 2, 2, 2], **kwargs) + + +def resnet26_3d(**kwargs): + return ResNet3d(Bottleneck, [2, 2, 2, 2], **kwargs) + + +def resnet34_3d(**kwargs): + return ResNet3d(BasicBlock, [3, 4, 6, 3], **kwargs) + + +def resnet50_3d(**kwargs): + return ResNet3d(Bottleneck, [3, 4, 6, 3], **kwargs) + + +def resnet101_3d(**kwargs): + return ResNet3d(Bottleneck, [3, 4, 23, 3], **kwargs) + + +def resnet152_3d(**kwargs): + return ResNet3d(Bottleneck, [3, 8, 36, 3], **kwargs) + + +def resnet200_3d(**kwargs): + return ResNet3d(Bottleneck, [3, 24, 36, 3], **kwargs) diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 96043ce9..bdf2cc17 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -57,6 +57,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.visual_question_answering: (Pipelines.visual_question_answering, 'damo/mplug_visual-question-answering_coco_large_en'), + Tasks.video_embedding: (Pipelines.cmdssl_video_embedding, + 'damo/cv_r2p1d_video_embedding'), Tasks.text_to_image_synthesis: (Pipelines.text_to_image_synthesis, 'damo/cv_imagen_text-to-image-synthesis_tiny') diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index aa393ec5..ce769c44 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -1,6 +1,7 @@ try: from .action_recognition_pipeline import ActionRecognitionPipeline from .animal_recog_pipeline import AnimalRecogPipeline + from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline except ModuleNotFoundError as e: if str(e) == "No module named 'torch'": pass diff --git a/modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py b/modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py new file mode 100644 index 00000000..c3a73bc6 --- /dev/null +++ b/modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py @@ -0,0 +1,157 @@ +import math +import os.path as osp +from typing import Any, Dict + +import cv2 +import decord +import numpy as np +import PIL +import torch +import torchvision.transforms.functional as TF +from decord import VideoReader, cpu +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.cmdssl_video_embedding.resnet2p1d import \ + resnet26_2p1d +from modelscope.pipelines.base import Input +from modelscope.pipelines.outputs import OutputKeys +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from ..base import Pipeline +from ..builder import PIPELINES + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_embedding, module_name=Pipelines.cmdssl_video_embedding) +class CMDSSLVideoEmbeddingPipeline(Pipeline): + + def __init__(self, model: str): + super().__init__(model=model) + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + self.model = resnet26_2p1d(num_classes=None, last_pool=True) + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = 
torch.device('cpu') + self.model = self.model.to(self._device).eval().requires_grad_(False) + self.model.load_state_dict(torch.load(model_path)) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + decord.bridge.set_bridge('native') + + transforms = VCompose([ + VRescale(size=self.cfg.DATA.scale_size), + VCenterCrop(size=self.cfg.DATA.crop_size), + VToTensor(), + VNormalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std) + ]) + + clip_len = (self.cfg.DATA.video_frames + - 1) * self.cfg.DATA.video_stride + 1 + vr = VideoReader(input, ctx=cpu(0)) + if len(vr) <= clip_len: + init_frames = np.zeros(self.cfg.DATA.multi_crop, dtype=int) + else: + init_frames = np.linspace(0, + len(vr) - clip_len, + self.cfg.DATA.multi_crop + 1) + init_frames = ((init_frames[1:] + init_frames[:-1]) + / 2.).astype(int) + + indices = np.arange(0, clip_len, self.cfg.DATA.video_stride) + indices = (init_frames[:, None] + indices[None, :]).reshape(-1) + indices[indices >= len(vr)] = 0 + + frames = torch.from_numpy(vr.get_batch(indices).asnumpy()).chunk( + self.cfg.DATA.multi_crop, dim=0) + frames = [ + transforms([Image.fromarray(f) for f in u.numpy()]) for u in frames + ] + frames = torch.stack(frames, dim=0) + result = {'video_data': frames} + return result + + @torch.no_grad() + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + frames = input['video_data'].to(self._device) + feature = self.model(frames) + feature = feature.mean(0) + return {OutputKeys.VIDEO_EMBEDDING: feature.data.cpu().numpy()} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + +class VCompose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, item): + for t in self.transforms: + item = t(item) + return item + + +class VRescale(object): + + def __init__(self, size=128): + self.size = size + + def __call__(self, vclip): + w, h = vclip[0].size + scale = self.size / min(w, h) + out_w, out_h = int(round(w * scale)), int(round(h * scale)) + vclip = [u.resize((out_w, out_h), Image.BILINEAR) for u in vclip] + return vclip + + +class VCenterCrop(object): + + def __init__(self, size=112): + self.size = size + + def __call__(self, vclip): + w, h = vclip[0].size + assert min(w, h) >= self.size + x1 = (w - self.size) // 2 + y1 = (h - self.size) // 2 + vclip = [ + u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in vclip + ] + return vclip + + +class VToTensor(object): + + def __call__(self, vclip): + vclip = torch.stack([TF.to_tensor(u) for u in vclip], dim=1) + return vclip + + +class VNormalize(object): + + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + + def __call__(self, vclip): + assert vclip.min() > -0.1 and vclip.max() < 1.1, \ + 'vclip values should be in [0, 1]' + vclip = vclip.clone() + if not isinstance(self.mean, torch.Tensor): + self.mean = vclip.new_tensor(self.mean).view(-1, 1, 1, 1) + if not isinstance(self.std, torch.Tensor): + self.std = vclip.new_tensor(self.std).view(-1, 1, 1, 1) + vclip.sub_(self.mean).div_(self.std) + return vclip diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index 1468baa5..b418fe7f 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -21,6 +21,7 @@ class OutputKeys(object): TRANSLATION = 'translation' RESPONSE = 'response' PREDICTION = 'prediction' + VIDEO_EMBEDDING = 'video_embedding' TASK_OUTPUTS = { @@ -90,6 +91,12 @@ TASK_OUTPUTS = { # } 
Tasks.ocr_detection: [OutputKeys.POLYGONS], + # video embedding result for single video + # { + # "video_embedding": np.array with shape [D], + # } + Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING], + # ============ nlp tasks =================== # text classification result for single sample diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 69faaf6a..d4a19304 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -31,6 +31,7 @@ class Tasks(object): image_matting = 'image-matting' ocr_detection = 'ocr-detection' action_recognition = 'action-recognition' + video_embedding = 'video-embedding' # nlp tasks word_segmentation = 'word-segmentation' diff --git a/tests/pipelines/test_cmdssl_video_embedding.py b/tests/pipelines/test_cmdssl_video_embedding.py new file mode 100644 index 00000000..dd06305a --- /dev/null +++ b/tests/pipelines/test_cmdssl_video_embedding.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# !/usr/bin/env python +import os.path as osp +import shutil +import tempfile +import unittest + +import cv2 + +from modelscope.fileio import File +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class CMDSSLVideoEmbeddingTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + videossl_pipeline = pipeline( + Tasks.video_embedding, model='damo/cv_r2p1d_video_embedding') + result = videossl_pipeline( + 'data/test/videos/action_recognition_test_video.mp4') + + print(f'video embedding output: {result}.') + + +if __name__ == '__main__': + unittest.main()
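
The sketches below are illustrative only and are not part of the patches above. First, PATCH 2/3 registers a default model for the text-to-image synthesis task in DEFAULT_MODEL_FOR_PIPELINE, so the pipeline can be built without an explicit model id; the snippet assumes the 'damo/cv_imagen_text-to-image-synthesis_tiny' model is reachable on the model hub.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# With no model id, builder.py now resolves the task's default model,
# 'damo/cv_imagen_text-to-image-synthesis_tiny'.
t2i = pipeline(Tasks.text_to_image_synthesis)

# Passing an explicit model (a Model instance or a str id) still works and
# takes precedence over the default.
t2i_explicit = pipeline(
    Tasks.text_to_image_synthesis,
    model='damo/cv_imagen_text-to-image-synthesis_tiny')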
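
Second, a minimal usage sketch for the video-embedding pipeline added in PATCH 3/3, mirroring tests/pipelines/test_cmdssl_video_embedding.py; it assumes the 'damo/cv_r2p1d_video_embedding' model and the sample video from the test data are available locally. Like the action-recognition pipeline after PATCH 1/3, this pipeline falls back to CPU when CUDA is unavailable.

from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks

# Build the pipeline registered under Tasks.video_embedding
# (default model per builder.py: 'damo/cv_r2p1d_video_embedding').
video_embedding = pipeline(
    Tasks.video_embedding, model='damo/cv_r2p1d_video_embedding')

# Returns {'video_embedding': np.array of shape [D]}: the R(2+1)D clip
# features averaged over the sampled crops
# (see CMDSSLVideoEmbeddingPipeline.forward).
result = video_embedding(
    'data/test/videos/action_recognition_test_video.mp4')
print(result[OutputKeys.VIDEO_EMBEDDING].shape)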