diff --git a/data/test/videos/action_recognition_test_video.mp4 b/data/test/videos/action_recognition_test_video.mp4 new file mode 100644 index 00000000..9197b770 --- /dev/null +++ b/data/test/videos/action_recognition_test_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24dc4237b1197321ee8486bb983fa01fd47e2b4afdb3c2df24229e5f2bd20119 +size 1475924 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index a8677c16..3f9cc64e 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -40,6 +40,7 @@ class Pipelines(object): image_matting = 'unet-image-matting' person_image_cartoon = 'unet-person-image-cartoon' ocr_detection = 'resnet18-ocr-detection' + action_recognition = 'TAdaConv_action-recognition' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/action_recognition/__init__.py b/modelscope/models/cv/action_recognition/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/action_recognition/models.py b/modelscope/models/cv/action_recognition/models.py new file mode 100644 index 00000000..e85b6d81 --- /dev/null +++ b/modelscope/models/cv/action_recognition/models.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + +from .tada_convnext import TadaConvNeXt + + +class BaseVideoModel(nn.Module): + """ + Standard video model. + The model is divided into the backbone and the head, where the backbone + extracts features and the head performs classification. + + The backbones can be defined in model/base/backbone.py or anywhere else + as long as the backbone is registered by the BACKBONE_REGISTRY. + The heads can be defined in model/module_zoo/heads/ or anywhere else + as long as the head is registered by the HEAD_REGISTRY. + + The registries automatically finds the registered modules and construct + the base video model. + """ + + def __init__(self, cfg): + """ + Args: + cfg (Config): global config object. + """ + super(BaseVideoModel, self).__init__() + # the backbone is created according to meta-architectures + # defined in models/base/backbone.py + self.backbone = TadaConvNeXt(cfg) + + # the head is created according to the heads + # defined in models/module_zoo/heads + self.head = BaseHead(cfg) + + def forward(self, x): + x = self.backbone(x) + x = self.head(x) + return x + + +class BaseHead(nn.Module): + """ + Constructs base head. + """ + + def __init__( + self, + cfg, + ): + """ + Args: + cfg (Config): global config object. + """ + super(BaseHead, self).__init__() + self.cfg = cfg + dim = cfg.VIDEO.BACKBONE.NUM_OUT_FEATURES + num_classes = cfg.VIDEO.HEAD.NUM_CLASSES + dropout_rate = cfg.VIDEO.HEAD.DROPOUT_RATE + activation_func = cfg.VIDEO.HEAD.ACTIVATION + self._construct_head(dim, num_classes, dropout_rate, activation_func) + + def _construct_head(self, dim, num_classes, dropout_rate, activation_func): + self.global_avg_pool = nn.AdaptiveAvgPool3d(1) + + if dropout_rate > 0.0: + self.dropout = nn.Dropout(dropout_rate) + + self.out = nn.Linear(dim, num_classes, bias=True) + + if activation_func == 'softmax': + self.activation = nn.Softmax(dim=-1) + elif activation_func == 'sigmoid': + self.activation = nn.Sigmoid() + else: + raise NotImplementedError('{} is not supported as an activation' + 'function.'.format(activation_func)) + + def forward(self, x): + if len(x.shape) == 5: + x = self.global_avg_pool(x) + # (N, C, T, H, W) -> (N, T, H, W, C). 
+ x = x.permute((0, 2, 3, 4, 1)) + if hasattr(self, 'dropout'): + out = self.dropout(x) + else: + out = x + out = self.out(out) + out = self.activation(out) + out = out.view(out.shape[0], -1) + return out, x.view(x.shape[0], -1) diff --git a/modelscope/models/cv/action_recognition/tada_convnext.py b/modelscope/models/cv/action_recognition/tada_convnext.py new file mode 100644 index 00000000..379b5271 --- /dev/null +++ b/modelscope/models/cv/action_recognition/tada_convnext.py @@ -0,0 +1,472 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import _pair, _triple + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """ + From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py. + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """ + From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py. + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class TadaConvNeXt(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - + https://arxiv.org/pdf/2201.03545.pdf + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
+ """ + + def __init__( + self, cfg + # in_chans=3, num_classes=1000, + # depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., + # layer_scale_init_value=1e-6, head_init_scale=1., + ): + super().__init__() + in_chans = cfg.VIDEO.BACKBONE.NUM_INPUT_CHANNELS + dims = cfg.VIDEO.BACKBONE.NUM_FILTERS + drop_path_rate = cfg.VIDEO.BACKBONE.DROP_PATH + depths = cfg.VIDEO.BACKBONE.DEPTH + layer_scale_init_value = cfg.VIDEO.BACKBONE.LARGE_SCALE_INIT_VALUE + stem_t_kernel_size = cfg.VIDEO.BACKBONE.STEM.T_KERNEL_SIZE if hasattr( + cfg.VIDEO.BACKBONE.STEM, 'T_KERNEL_SIZE') else 2 + t_stride = cfg.VIDEO.BACKBONE.STEM.T_STRIDE if hasattr( + cfg.VIDEO.BACKBONE.STEM, 'T_STRIDE') else 2 + + self.downsample_layers = nn.ModuleList( + ) # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv3d( + in_chans, + dims[0], + kernel_size=(stem_t_kernel_size, 4, 4), + stride=(t_stride, 4, 4), + padding=((stem_t_kernel_size - 1) // 2, 0, 0)), + LayerNorm(dims[0], eps=1e-6, data_format='channels_first')) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format='channels_first'), + nn.Conv3d( + dims[i], + dims[i + 1], + kernel_size=(1, 2, 2), + stride=(1, 2, 2)), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList( + ) # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] + cur = 0 + for i in range(4): + stage = nn.Sequential(*[ + TAdaConvNeXtBlock( + cfg, + dim=dims[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value) + for j in range(depths[i]) + ]) + self.stages.append(stage) + cur += depths[i] + + self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + + def forward_features(self, x): + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return self.norm(x.mean( + [-3, -2, -1])) # global average pooling, (N, C, H, W) -> (N, C) + + def forward(self, x): + if isinstance(x, dict): + x = x['video'] + x = self.forward_features(x) + return x + + def get_num_layers(self): + return 12, 0 + + +class ConvNeXtBlock(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, cfg, dim, drop_path=0., layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv3d( + dim, dim, kernel_size=(1, 7, 7), padding=(0, 3, 3), + groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, + 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W) + + x = input + self.drop_path(x) + return x + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, + normalized_shape, + eps=1e-6, + data_format='channels_last'): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ['channels_last', 'channels_first']: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == 'channels_last': + return F.layer_norm(x, self.normalized_shape, self.weight, + self.bias, self.eps) + elif self.data_format == 'channels_first': + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None, None] * x + self.bias[:, None, None, + None] + return x + + +class TAdaConvNeXtBlock(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_fi rst) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, cfg, dim, drop_path=0., layer_scale_init_value=1e-6): + super().__init__() + layer_scale_init_value = float(layer_scale_init_value) + self.dwconv = TAdaConv2d( + dim, + dim, + kernel_size=(1, 7, 7), + padding=(0, 3, 3), + groups=dim, + cal_dim='cout') + route_func_type = cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_TYPE + if route_func_type == 'normal': + self.dwconv_rf = RouteFuncMLP( + c_in=dim, + ratio=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_R, + kernels=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_K, + with_bias_cal=self.dwconv.bias is not None) + elif route_func_type == 'normal_lngelu': + self.dwconv_rf = RouteFuncMLPLnGelu( + c_in=dim, + ratio=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_R, + kernels=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_K, + with_bias_cal=self.dwconv.bias is not None) + else: + raise ValueError( + 'Unknown route_func_type: {}'.format(route_func_type)) + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, + 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x, self.dwconv_rf(x)) + x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W) + + x = input + self.drop_path(x) + return x + + +class RouteFuncMLPLnGelu(nn.Module): + """ + The routing function for generating the calibration weights. + """ + + def __init__(self, + c_in, + ratio, + kernels, + with_bias_cal=False, + bn_eps=1e-5, + bn_mmt=0.1): + """ + Args: + c_in (int): number of input channels. + ratio (int): reduction ratio for the routing function. + kernels (list): temporal kernel size of the stacked 1D convolutions + """ + super(RouteFuncMLPLnGelu, self).__init__() + self.c_in = c_in + self.with_bias_cal = with_bias_cal + self.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1)) + self.globalpool = nn.AdaptiveAvgPool3d(1) + self.g = nn.Conv3d( + in_channels=c_in, + out_channels=c_in, + kernel_size=1, + padding=0, + ) + self.a = nn.Conv3d( + in_channels=c_in, + out_channels=int(c_in // ratio), + kernel_size=[kernels[0], 1, 1], + padding=[kernels[0] // 2, 0, 0], + ) + # self.bn = nn.BatchNorm3d(int(c_in//ratio), eps=bn_eps, momentum=bn_mmt) + self.ln = LayerNorm( + int(c_in // ratio), eps=1e-6, data_format='channels_first') + self.gelu = nn.GELU() + # self.relu = nn.ReLU(inplace=True) + self.b = nn.Conv3d( + in_channels=int(c_in // ratio), + out_channels=c_in, + kernel_size=[kernels[1], 1, 1], + padding=[kernels[1] // 2, 0, 0], + bias=False) + self.b.skip_init = True + self.b.weight.data.zero_() # to make sure the initial values + # for the output is 1. + if with_bias_cal: + self.b_bias = nn.Conv3d( + in_channels=int(c_in // ratio), + out_channels=c_in, + kernel_size=[kernels[1], 1, 1], + padding=[kernels[1] // 2, 0, 0], + bias=False) + self.b_bias.skip_init = True + self.b_bias.weight.data.zero_() # to make sure the initial values + # for the output is 1. + + def forward(self, x): + g = self.globalpool(x) + x = self.avgpool(x) + x = self.a(x + self.g(g)) + # x = self.bn(x) + # x = self.relu(x) + x = self.ln(x) + x = self.gelu(x) + if self.with_bias_cal: + return [self.b(x) + 1, self.b_bias(x) + 1] + else: + return self.b(x) + 1 + + +class TAdaConv2d(nn.Module): + """ + Performs temporally adaptive 2D convolution. + Currently, only application on 5D tensors is supported, which makes TAdaConv2d + essentially a 3D convolution with temporal kernel size of 1. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + cal_dim='cin'): + super(TAdaConv2d, self).__init__() + """ + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + kernel_size (list): kernel size of TAdaConv2d. + stride (list): stride for the convolution in TAdaConv2d. + padding (list): padding for the convolution in TAdaConv2d. + dilation (list): dilation of the convolution in TAdaConv2d. + groups (int): number of groups for TAdaConv2d. + bias (bool): whether to use bias in TAdaConv2d. + calibration_mode (str): calibrated dimension in TAdaConv2d. + Supported input "cin", "cout". 
+ """ + + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + + assert kernel_size[0] == 1 + assert stride[0] == 1 + assert padding[0] == 0 + assert dilation[0] == 1 + assert cal_dim in ['cin', 'cout'] + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.cal_dim = cal_dim + + # base weights (W_b) + self.weight = nn.Parameter( + torch.Tensor(1, 1, out_channels, in_channels // groups, + kernel_size[1], kernel_size[2])) + if bias: + self.bias = nn.Parameter(torch.Tensor(1, 1, out_channels)) + else: + self.register_parameter('bias', None) + + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, x, alpha): + """ + Args: + x (tensor): feature to perform convolution on. + alpha (tensor): calibration weight for the base weights. + W_t = alpha_t * W_b + """ + if isinstance(alpha, list): + w_alpha, b_alpha = alpha[0], alpha[1] + else: + w_alpha = alpha + b_alpha = None + _, _, c_out, c_in, kh, kw = self.weight.size() + b, c_in, t, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).reshape(1, -1, h, w) + + if self.cal_dim == 'cin': + # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, 1, C, H(1), W(1) + # corresponding to calibrating the input channel + weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(2) + * self.weight).reshape(-1, c_in // self.groups, kh, kw) + elif self.cal_dim == 'cout': + # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C, 1, H(1), W(1) + # corresponding to calibrating the input channel + weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(3) + * self.weight).reshape(-1, c_in // self.groups, kh, kw) + + bias = None + if self.bias is not None: + if b_alpha is not None: + # b_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C + bias = (b_alpha.permute(0, 2, 1, 3, 4).squeeze() + * self.bias).reshape(-1) + else: + bias = self.bias.repeat(b, t, 1).reshape(-1) + output = F.conv2d( + x, + weight=weight, + bias=bias, + stride=self.stride[1:], + padding=self.padding[1:], + dilation=self.dilation[1:], + groups=self.groups * b * t) + + output = output.view(b, t, c_out, output.size(-2), + output.size(-1)).permute(0, 2, 1, 3, 4) + + return output + + def __repr__(self): + return f'TAdaConv2d({self.in_channels}, {self.out_channels}, kernel_size={self.kernel_size}, ' +\ + f"stride={self.stride}, padding={self.padding}, bias={self.bias is not None}, cal_dim=\"{self.cal_dim}\")" diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 40b09edc..22274b55 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -49,6 +49,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.ocr_detection: (Pipelines.ocr_detection, 'damo/cv_resnet18_ocr-detection-line-level_damo'), Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask_large') + Tasks.action_recognition: (Pipelines.action_recognition, + 'damo/cv_TAdaConv_action-recognition'), } diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 767c90d7..68d875ec 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -1,3 +1,4 @@ +from .action_recognition_pipeline import 
ActionRecognitionPipeline from .image_cartoon_pipeline import ImageCartoonPipeline from .image_matting_pipeline import ImageMattingPipeline from .ocr_detection_pipeline import OCRDetectionPipeline diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py new file mode 100644 index 00000000..845f8f9a --- /dev/null +++ b/modelscope/pipelines/cv/action_recognition_pipeline.py @@ -0,0 +1,65 @@ +import math +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.action_recognition.models import BaseVideoModel +from modelscope.pipelines.base import Input +from modelscope.preprocessors.video import ReadVideoData +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from ..base import Pipeline +from ..builder import PIPELINES + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.action_recognition, module_name=Pipelines.action_recognition) +class ActionRecognitionPipeline(Pipeline): + + def __init__(self, model: str): + super().__init__(model=model) + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + self.infer_model = BaseVideoModel(cfg=self.cfg).cuda() + self.infer_model.eval() + self.infer_model.load_state_dict(torch.load(model_path)['model_state']) + self.label_mapping = self.cfg.label_mapping + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + video_input_data = ReadVideoData(self.cfg, input).cuda() + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + result = {'video_data': video_input_data} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + pred = self.perform_inference(input['video_data']) + output_label = self.label_mapping[str(pred)] + return {'output_label': output_label} + + @torch.no_grad() + def perform_inference(self, data, max_bsz=4): + iter_num = math.ceil(data.size(0) / max_bsz) + preds_list = [] + for i in range(iter_num): + preds_list.append( + self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0]) + pred = torch.cat(preds_list, dim=0) + return pred.mean(dim=0).argmax().item() + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index 6140f726..3b1c67de 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -45,6 +45,12 @@ TASK_OUTPUTS = { Tasks.image_matting: ['output_png'], Tasks.image_generation: ['output_png'], + # action recognition result for single video + # { + # "output_label": "abseiling" + # } + Tasks.action_recognition: ['output_label'], + # pose estimation result for single sample # { # "poses": np.array with shape [num_pose, num_keypoint, 3], diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index a94cbca1..8ff5f935 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -5,7 +5,7 @@ from .base import Preprocessor from .builder import PREPROCESSORS, build_preprocessor from .common import Compose from 
.image import LoadImage, load_image -from .multi_model import OfaImageCaptionPreprocessor +from .multi_modal import OfaImageCaptionPreprocessor from .nlp import * # noqa F403 from .space.dialog_intent_prediction_preprocessor import * # noqa F403 from .space.dialog_modeling_preprocessor import * # noqa F403 diff --git a/modelscope/preprocessors/multi_model.py b/modelscope/preprocessors/multi_modal.py similarity index 95% rename from modelscope/preprocessors/multi_model.py rename to modelscope/preprocessors/multi_modal.py index aa0bc8a7..7c8f0fab 100644 --- a/modelscope/preprocessors/multi_model.py +++ b/modelscope/preprocessors/multi_modal.py @@ -73,7 +73,7 @@ class OfaImageCaptionPreprocessor(Preprocessor): self.eos_item = torch.LongTensor([task.src_dict.eos()]) self.pad_idx = task.src_dict.pad() - @type_assert(object, (str, tuple)) + @type_assert(object, (str, tuple, Image.Image)) def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]: def encode_text(text, length=None, append_bos=False, append_eos=False): @@ -89,8 +89,8 @@ class OfaImageCaptionPreprocessor(Preprocessor): s = torch.cat([s, self.eos_item]) return s - if isinstance(input, Image.Image): - patch_image = self.patch_resize_transform(input).unsqueeze(0) + if isinstance(data, Image.Image): + patch_image = self.patch_resize_transform(data).unsqueeze(0) else: patch_image = self.patch_resize_transform( load_image(data)).unsqueeze(0) diff --git a/modelscope/preprocessors/video.py b/modelscope/preprocessors/video.py new file mode 100644 index 00000000..262fdaa5 --- /dev/null +++ b/modelscope/preprocessors/video.py @@ -0,0 +1,232 @@ +import math +import os +import random + +import decord +import numpy as np +import torch +import torch.nn as nn +import torch.utils.data +import torch.utils.dlpack as dlpack +import torchvision.transforms._transforms_video as transforms +from decord import VideoReader +from torchvision.transforms import Compose + + +def ReadVideoData(cfg, video_path): + """ simple interface to load video frames from file + + Args: + cfg (Config): The global config object. + video_path (str): video file path + """ + data = _decode_video(cfg, video_path) + transform = kinetics400_tranform(cfg) + data_list = [] + for i in range(data.size(0)): + for j in range(cfg.TEST.NUM_SPATIAL_CROPS): + transform.transforms[1].set_spatial_index(j) + data_list.append(transform(data[i])) + return torch.stack(data_list, dim=0) + + +def kinetics400_tranform(cfg): + """ + Configs the transform for the kinetics-400 dataset. + We apply controlled spatial cropping and normalization. + Args: + cfg (Config): The global config object. + """ + resize_video = KineticsResizedCrop( + short_side_range=[cfg.DATA.TEST_SCALE, cfg.DATA.TEST_SCALE], + crop_size=cfg.DATA.TEST_CROP_SIZE, + num_spatial_crops=cfg.TEST.NUM_SPATIAL_CROPS) + std_transform_list = [ + transforms.ToTensorVideo(), resize_video, + transforms.NormalizeVideo( + mean=cfg.DATA.MEAN, std=cfg.DATA.STD, inplace=True) + ] + return Compose(std_transform_list) + + +def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx, + num_clips, num_frames, interval, minus_interval): + """ + Generates the frame index list using interval based sampling. + Args: + vid_length (int): the length of the whole video (valid selection range). 
+        vid_fps (int): the original video fps
+        target_fps (int): the normalized video fps
+        clip_idx (int): -1 for random temporal sampling, and positive values for
+            sampling specific clip from the video
+        num_clips (int): the total clips to be sampled from each video.
+            combined with clip_idx, the sampled video is the "clip_idx-th"
+            video from "num_clips" videos.
+        num_frames (int): number of frames in each sampled clip.
+        interval (int): the interval to sample each frame.
+        minus_interval (bool): if True, the clip ends one sampling interval
+            (rather than one frame) before the computed clip length.
+    Returns:
+        index (tensor): the sampled frame indexes
+    """
+    if num_frames == 1:
+        index = [random.randint(0, vid_length - 1)]
+    else:
+        # transform FPS
+        clip_length = num_frames * interval * vid_fps / target_fps
+
+        max_idx = max(vid_length - clip_length, 0)
+        start_idx = clip_idx * math.floor(max_idx / (num_clips - 1))
+        if minus_interval:
+            end_idx = start_idx + clip_length - interval
+        else:
+            end_idx = start_idx + clip_length - 1
+
+        index = torch.linspace(start_idx, end_idx, num_frames)
+        index = torch.clamp(index, 0, vid_length - 1).long()
+
+    return index
+
+
+def _decode_video_frames_list(cfg, frames_list, vid_fps):
+    """
+    Decodes the specified frames from a list of numpy frames.
+    Args:
+        cfg (Config): The global config object.
+        frames_list (list): all frames for a video, each frame should be a numpy array.
+        vid_fps (int): the fps of this video.
+    Returns:
+        frames (Tensor): video tensor data
+    """
+    assert isinstance(frames_list, list)
+    num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
+
+    frame_list = []
+    for clip_idx in range(num_clips_per_video):
+        # for each clip in the video,
+        # a list is generated before decoding the specified frames from the video
+        list_ = _interval_based_sampling(
+            len(frames_list), vid_fps, cfg.DATA.TARGET_FPS, clip_idx,
+            num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES,
+            cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL)
+        frames = None
+        frames = torch.from_numpy(
+            np.stack([frames_list[l_index] for l_index in list_.tolist()],
+                     axis=0))
+        frame_list.append(frames)
+    frames = torch.stack(frame_list)
+    if num_clips_per_video == 1:
+        frames = frames.squeeze(0)
+
+    return frames
+
+
+def _decode_video(cfg, path):
+    """
+    Decodes the specified frames from a video file.
+    Args:
+        cfg (Config): The global config object.
+        path (str): video file path.
+    Returns:
+        frames (Tensor): video tensor data
+    """
+    vr = VideoReader(path)
+
+    num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
+
+    frame_list = []
+    for clip_idx in range(num_clips_per_video):
+        # for each clip in the video,
+        # a list is generated before decoding the specified frames from the video
+        list_ = _interval_based_sampling(
+            len(vr), vr.get_avg_fps(), cfg.DATA.TARGET_FPS, clip_idx,
+            num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES,
+            cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL)
+        frames = None
+        if path.endswith('.avi'):
+            append_list = torch.arange(0, list_[0], 4)
+            frames = dlpack.from_dlpack(
+                vr.get_batch(torch.cat([append_list,
+                                        list_])).to_dlpack()).clone()
+            frames = frames[append_list.shape[0]:]
+        else:
+            frames = dlpack.from_dlpack(
+                vr.get_batch(list_).to_dlpack()).clone()
+        frame_list.append(frames)
+    frames = torch.stack(frame_list)
+    if num_clips_per_video == 1:
+        frames = frames.squeeze(0)
+    del vr
+    return frames
+
+
+class KineticsResizedCrop(object):
+    """Perform resize and crop for the Kinetics-400 dataset.
+    Args:
+        short_side_range (list): The range for the short-side length. In inference, this should be [256, 256].
+        crop_size (int): The cropped size for frames.
+ num_spatial_crops (int): The number of the cropped spatial regions in each video. + """ + + def __init__( + self, + short_side_range, + crop_size, + num_spatial_crops=1, + ): + self.idx = -1 + self.short_side_range = short_side_range + self.crop_size = int(crop_size) + self.num_spatial_crops = num_spatial_crops + + def _get_controlled_crop(self, clip): + """Perform controlled crop for video tensor. + Args: + clip (Tensor): the video data, the shape is [T, C, H, W] + """ + _, _, clip_height, clip_width = clip.shape + + length = self.short_side_range[0] + + if clip_height < clip_width: + new_clip_height = int(length) + new_clip_width = int(clip_width / clip_height * new_clip_height) + new_clip = torch.nn.functional.interpolate( + clip, size=(new_clip_height, new_clip_width), mode='bilinear') + else: + new_clip_width = int(length) + new_clip_height = int(clip_height / clip_width * new_clip_width) + new_clip = torch.nn.functional.interpolate( + clip, size=(new_clip_height, new_clip_width), mode='bilinear') + x_max = int(new_clip_width - self.crop_size) + y_max = int(new_clip_height - self.crop_size) + if self.num_spatial_crops == 1: + x = x_max // 2 + y = y_max // 2 + elif self.num_spatial_crops == 3: + if self.idx == 0: + if new_clip_width == length: + x = x_max // 2 + y = 0 + elif new_clip_height == length: + x = 0 + y = y_max // 2 + elif self.idx == 1: + x = x_max // 2 + y = y_max // 2 + elif self.idx == 2: + if new_clip_width == length: + x = x_max // 2 + y = y_max + elif new_clip_height == length: + x = x_max + y = y_max // 2 + return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size] + + def set_spatial_index(self, idx): + """Set the spatial cropping index for controlled cropping.. + Args: + idx (int): the spatial index. The value should be in [0, 1, 2], means [left, center, right], respectively. 
+ """ + self.idx = idx + + def __call__(self, clip): + return self._get_controlled_crop(clip) diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 450738e8..75e2d04d 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -29,6 +29,7 @@ class Tasks(object): image_generation = 'image-generation' image_matting = 'image-matting' ocr_detection = 'ocr-detection' + action_recognition = 'action-recognition' # nlp tasks zero_shot_classification = 'zero-shot-classification' diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index 01a1b1b0..868e751b 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -2,21 +2,39 @@ import os import os.path as osp -from typing import List, Union +from typing import List, Optional, Union -from numpy import deprecate +from requests import HTTPError from modelscope.hub.file_download import model_file_download from modelscope.hub.snapshot_download import snapshot_download -from modelscope.hub.utils.utils import get_cache_dir from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile -# temp solution before the hub-cache is in place -@deprecate -def get_model_cache_dir(model_id: str): - return os.path.join(get_cache_dir(), model_id) +def create_model_if_not_exist( + api, + model_id: str, + chinese_name: str, + visibility: Optional[int] = 5, # 1-private, 5-public + license: Optional[str] = 'apache-2.0', + revision: Optional[str] = 'master'): + exists = True + try: + api.get_model(model_id=model_id, revision=revision) + except HTTPError: + exists = False + if exists: + print(f'model {model_id} already exists, skip creation.') + return False + else: + api.create_model( + model_id=model_id, + chinese_name=chinese_name, + visibility=visibility, + license=license) + print(f'model {model_id} successfully created.') + return True def read_config(model_id_or_path: str): diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py index b26b899d..8009b084 100644 --- a/modelscope/utils/registry.py +++ b/modelscope/utils/registry.py @@ -78,7 +78,7 @@ class Registry(object): f'{self._name}[{default_group}] and will ' 'be overwritten') logger.warning(f'{self._modules[default_group][module_name]}' - 'to {module_cls}') + f'to {module_cls}') # also register module in the default group for faster access # only by module name self._modules[default_group][module_name] = module_cls diff --git a/requirements/cv.txt b/requirements/cv.txt index 5bec8ba7..513dae99 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -1,2 +1,3 @@ +decord>=0.6.0 easydict tf_slim diff --git a/tests/hub/test_hub_examples.py b/tests/hub/test_hub_examples.py new file mode 100644 index 00000000..b63445af --- /dev/null +++ b/tests/hub/test_hub_examples.py @@ -0,0 +1,33 @@ +import unittest + +from maas_hub.maas_api import MaasApi + +from modelscope.utils.hub import create_model_if_not_exist + +USER_NAME = 'maasadmin' +PASSWORD = '12345678' + + +class HubExampleTest(unittest.TestCase): + + def setUp(self): + self.api = MaasApi() + # note this is temporary before official account management is ready + self.api.login(USER_NAME, PASSWORD) + + @unittest.skip('to be used for local test only') + def test_example_model_creation(self): + # ATTENTION:change to proper model names before use + model_name = 'cv_unet_person-image-cartoon_compound-models' + model_chinese_name = '达摩卡通化模型' + model_org = 'damo' + model_id = '%s/%s' % (model_org, model_name) + + created = create_model_if_not_exist(self.api, 
model_id, + model_chinese_name) + if not created: + print('!! NOT created since model already exists !!') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_operation.py b/tests/hub/test_hub_operation.py index 2277860b..d44cd7c1 100644 --- a/tests/hub/test_hub_operation.py +++ b/tests/hub/test_hub_operation.py @@ -1,6 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os -import os.path as osp import subprocess import tempfile import unittest @@ -8,7 +7,6 @@ import uuid from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.hub.file_download import model_file_download -from modelscope.hub.repository import Repository from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.utils.utils import get_gitlab_domain diff --git a/tests/pipelines/test_action_recognition.py b/tests/pipelines/test_action_recognition.py new file mode 100644 index 00000000..b524ca18 --- /dev/null +++ b/tests/pipelines/test_action_recognition.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# !/usr/bin/env python +import os.path as osp +import shutil +import tempfile +import unittest + +import cv2 + +from modelscope.fileio import File +from modelscope.pipelines import pipeline +from modelscope.pydatasets import PyDataset +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class ActionRecognitionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_TAdaConv_action-recognition' + + @unittest.skip('deprecated, download model from model hub instead') + def test_run_with_direct_file_download(self): + model_path = 'https://aquila2-online-models.oss-cn-shanghai.aliyuncs.com/maas_test/pytorch_model.pt' + config_path = 'https://aquila2-online-models.oss-cn-shanghai.aliyuncs.com/maas_test/configuration.json' + with tempfile.TemporaryDirectory() as tmp_dir: + model_file = osp.join(tmp_dir, ModelFile.TORCH_MODEL_FILE) + with open(model_file, 'wb') as ofile1: + ofile1.write(File.read(model_path)) + config_file = osp.join(tmp_dir, ModelFile.CONFIGURATION) + with open(config_file, 'wb') as ofile2: + ofile2.write(File.read(config_path)) + recognition_pipeline = pipeline( + Tasks.action_recognition, model=tmp_dir) + result = recognition_pipeline( + 'data/test/videos/action_recognition_test_video.mp4') + print(f'recognition output: {result}.') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + recognition_pipeline = pipeline( + Tasks.action_recognition, model=self.model_id) + result = recognition_pipeline( + 'data/test/videos/action_recognition_test_video.mp4') + + print(f'recognition output: {result}.') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_default_model(self): + recognition_pipeline = pipeline(Tasks.action_recognition) + result = recognition_pipeline( + 'data/test/videos/action_recognition_test_video.mp4') + + print(f'recognition output: {result}.') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 23ea678b..751b6975 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -60,7 +60,7 @@ class ImageMattingTest(unittest.TestCase): cv2.imwrite('result.png', result['output_png']) print(f'Output written to {osp.abspath("result.png")}') - @unittest.skipUnless(test_level() >= 2, 'skip 
test in current test level') + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_modelscope_dataset(self): dataset = PyDataset.load('beans', split='train', target='image') img_matting = pipeline(Tasks.image_matting, model=self.model_id) diff --git a/tests/pydatasets/test_py_dataset.py b/tests/pydatasets/test_py_dataset.py index bc38e369..4ad767fa 100644 --- a/tests/pydatasets/test_py_dataset.py +++ b/tests/pydatasets/test_py_dataset.py @@ -33,8 +33,6 @@ class ImgPreprocessor(Preprocessor): class PyDatasetTest(unittest.TestCase): - @unittest.skipUnless(test_level() >= 2, - 'skip test due to dataset api problem') def test_ds_basic(self): ms_ds_full = PyDataset.load('squad') ms_ds_full_hf = hfdata.load_dataset('squad')
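
Usage note: a minimal sketch of how the new pipeline is exercised by the tests above. It assumes the model id 'damo/cv_TAdaConv_action-recognition' is available on the model hub and that a CUDA device is present, since ActionRecognitionPipeline moves both the model and the decoded video tensor to GPU.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline; omitting `model=` falls back to the default
# 'damo/cv_TAdaConv_action-recognition' registered in DEFAULT_MODEL_FOR_PIPELINE.
recognizer = pipeline(Tasks.action_recognition,
                      model='damo/cv_TAdaConv_action-recognition')
result = recognizer('data/test/videos/action_recognition_test_video.mp4')
print(result['output_label'])  # e.g. 'abseiling', per the TASK_OUTPUTS schema above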