mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
merge with master
This commit is contained in:
@@ -49,6 +49,7 @@ class Pipelines(object):
|
||||
ocr_detection = 'resnet18-ocr-detection'
|
||||
action_recognition = 'TAdaConv_action-recognition'
|
||||
animal_recognation = 'resnet101-animal_recog'
|
||||
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
|
||||
|
||||
# nlp tasks
|
||||
sentence_similarity = 'sentence-similarity'
|
||||
|
||||
3
modelscope/models/cv/cmdssl_video_embedding/__init__.py
Normal file
3
modelscope/models/cv/cmdssl_video_embedding/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .c3d import C3D
|
||||
from .resnet2p1d import resnet26_2p1d
|
||||
from .resnet3d import resnet26_3d
|
||||
121
modelscope/models/cv/cmdssl_video_embedding/c3d.py
Normal file
121
modelscope/models/cv/cmdssl_video_embedding/c3d.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def conv3x3x3(in_planes, out_planes, stride=1):
|
||||
return nn.Conv3d(
|
||||
in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
|
||||
|
||||
|
||||
class C3D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
num_classes=1000,
|
||||
dropout=0.5,
|
||||
inplanes=3,
|
||||
norm_layer=None,
|
||||
last_pool=True):
|
||||
super(C3D, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm3d
|
||||
if not last_pool and num_classes is not None:
|
||||
raise ValueError('num_classes should be None when last_pool=False')
|
||||
|
||||
self.conv1 = conv3x3x3(inplanes, 64)
|
||||
self.bn1 = norm_layer(64)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
||||
|
||||
self.conv2 = conv3x3x3(64, 128)
|
||||
self.bn2 = norm_layer(128)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
|
||||
|
||||
self.conv3a = conv3x3x3(128, 256)
|
||||
self.bn3a = norm_layer(256)
|
||||
self.relu3a = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv3b = conv3x3x3(256, 256)
|
||||
self.bn3b = norm_layer(256)
|
||||
self.relu3b = nn.ReLU(inplace=True)
|
||||
self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
|
||||
|
||||
self.conv4a = conv3x3x3(256, 512)
|
||||
self.bn4a = norm_layer(512)
|
||||
self.relu4a = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv4b = conv3x3x3(512, 512)
|
||||
self.bn4b = norm_layer(512)
|
||||
self.relu4b = nn.ReLU(inplace=True)
|
||||
self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
|
||||
|
||||
self.conv5a = conv3x3x3(512, 512)
|
||||
self.bn5a = norm_layer(512)
|
||||
self.relu5a = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv5b = conv3x3x3(512, 512)
|
||||
self.bn5b = norm_layer(512)
|
||||
self.relu5b = nn.ReLU(inplace=True)
|
||||
self.pool5 = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None
|
||||
|
||||
if num_classes is None:
|
||||
self.dropout = None
|
||||
self.fc = None
|
||||
else:
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.fc = nn.Linear(512, num_classes)
|
||||
self.out_planes = 512
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv3d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu1(x)
|
||||
x = self.pool1(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
x = self.relu2(x)
|
||||
x = self.pool2(x)
|
||||
|
||||
x = self.conv3a(x)
|
||||
x = self.bn3a(x)
|
||||
x = self.relu3a(x)
|
||||
|
||||
x = self.conv3b(x)
|
||||
x = self.bn3b(x)
|
||||
x = self.relu3b(x)
|
||||
x = self.pool3(x)
|
||||
|
||||
x = self.conv4a(x)
|
||||
x = self.bn4a(x)
|
||||
x = self.relu4a(x)
|
||||
|
||||
x = self.conv4b(x)
|
||||
x = self.bn4b(x)
|
||||
x = self.relu4b(x)
|
||||
x = self.pool4(x)
|
||||
|
||||
x = self.conv5a(x)
|
||||
x = self.bn5a(x)
|
||||
x = self.relu5a(x)
|
||||
|
||||
x = self.conv5b(x)
|
||||
x = self.bn5b(x)
|
||||
x = self.relu5b(x)
|
||||
|
||||
if self.pool5:
|
||||
x = self.pool5(x)
|
||||
x = torch.flatten(x, 1)
|
||||
if self.dropout and self.fc:
|
||||
x = self.dropout(x)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
339
modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py
Normal file
339
modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def conv1x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
return nn.Conv3d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=(1, 3, 3),
|
||||
stride=(1, stride, stride),
|
||||
padding=(0, dilation, dilation),
|
||||
groups=groups,
|
||||
bias=False,
|
||||
dilation=(1, dilation, dilation))
|
||||
|
||||
|
||||
def conv3x1x1(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
return nn.Conv3d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=(3, 1, 1),
|
||||
stride=(stride, 1, 1),
|
||||
padding=(dilation, 0, 0),
|
||||
groups=groups,
|
||||
bias=False,
|
||||
dilation=(dilation, 1, 1))
|
||||
|
||||
|
||||
def conv1x1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv3d(
|
||||
in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
downsample=None,
|
||||
groups=1,
|
||||
base_width=64,
|
||||
dilation=1,
|
||||
norm_layer=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm3d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError(
|
||||
'BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError(
|
||||
'Dilation > 1 not supported in BasicBlock')
|
||||
|
||||
midplanes1 = (inplanes * planes * 3 * 3 * 3) // (
|
||||
inplanes * 3 * 3 + planes * 3)
|
||||
self.conv1_s = conv1x3x3(inplanes, midplanes1, stride)
|
||||
self.bn1_s = norm_layer(midplanes1)
|
||||
self.conv1_t = conv3x1x1(midplanes1, planes, stride)
|
||||
self.bn1_t = norm_layer(planes)
|
||||
|
||||
midplanes2 = (planes * planes * 3 * 3 * 3) // (
|
||||
planes * 3 * 3 + planes * 3)
|
||||
self.conv2_s = conv1x3x3(planes, midplanes2)
|
||||
self.bn2_s = norm_layer(midplanes2)
|
||||
self.conv2_t = conv3x1x1(midplanes2, planes)
|
||||
self.bn2_t = norm_layer(planes)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1_s(x)
|
||||
out = self.bn1_s(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv1_t(out)
|
||||
out = self.bn1_t(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2_s(out)
|
||||
out = self.bn2_s(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2_t(out)
|
||||
out = self.bn2_t(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
downsample=None,
|
||||
groups=1,
|
||||
base_width=64,
|
||||
dilation=1,
|
||||
norm_layer=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm3d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
|
||||
self.conv1 = conv1x1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
|
||||
midplanes = (width * width * 3 * 3 * 3) // (width * 3 * 3 + width * 3)
|
||||
self.conv2_s = conv1x3x3(width, midplanes, stride, groups, dilation)
|
||||
self.bn2_s = norm_layer(midplanes)
|
||||
self.conv2_t = conv3x1x1(midplanes, width, stride, groups, dilation)
|
||||
self.bn2_t = norm_layer(width)
|
||||
|
||||
self.conv3 = conv1x1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2_s(out)
|
||||
out = self.bn2_s(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2_t(out)
|
||||
out = self.bn2_t(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet2p1d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
block,
|
||||
layers,
|
||||
num_classes=None,
|
||||
zero_init_residual=True,
|
||||
groups=1,
|
||||
width_per_group=64,
|
||||
replace_stride_with_dilation=None,
|
||||
dropout=0.5,
|
||||
inplanes=3,
|
||||
first_stride=2,
|
||||
norm_layer=None,
|
||||
last_pool=True):
|
||||
super(ResNet2p1d, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm3d
|
||||
if not last_pool and num_classes is not None:
|
||||
raise ValueError('num_classes should be None when last_pool=False')
|
||||
self._norm_layer = norm_layer
|
||||
self.first_stride = first_stride
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError('replace_stride_with_dilation should be None '
|
||||
'or a 3-element tuple, got {}'.format(
|
||||
replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
|
||||
midplanes = (3 * self.inplanes * 3 * 7 * 7) // (3 * 7 * 7
|
||||
+ self.inplanes * 3)
|
||||
self.conv1_s = nn.Conv3d(
|
||||
inplanes,
|
||||
midplanes,
|
||||
kernel_size=(1, 7, 7),
|
||||
stride=(1, first_stride, first_stride),
|
||||
padding=(0, 3, 3),
|
||||
bias=False)
|
||||
self.bn1_s = norm_layer(midplanes)
|
||||
self.conv1_t = nn.Conv3d(
|
||||
midplanes,
|
||||
self.inplanes,
|
||||
kernel_size=(3, 1, 1),
|
||||
stride=(1, 1, 1),
|
||||
padding=(1, 0, 0),
|
||||
bias=False)
|
||||
self.bn1_t = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool3d(
|
||||
kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
|
||||
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(
|
||||
block,
|
||||
128,
|
||||
layers[1],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(
|
||||
block,
|
||||
256,
|
||||
layers[2],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(
|
||||
block,
|
||||
512,
|
||||
layers[3],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None
|
||||
if num_classes is None:
|
||||
self.dropout = None
|
||||
self.fc = None
|
||||
else:
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
self.out_planes = 512 * block.expansion
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv3d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2_t.weight, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion))
|
||||
|
||||
layers = []
|
||||
layers.append(
|
||||
block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(
|
||||
block(
|
||||
self.inplanes,
|
||||
planes,
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1_s(x)
|
||||
x = self.bn1_s(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv1_t(x)
|
||||
x = self.bn1_t(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
if self.avgpool:
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
if self.dropout and self.fc:
|
||||
x = self.dropout(x)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def resnet10_2p1d(**kwargs):
|
||||
return ResNet2p1d(BasicBlock, [1, 1, 1, 1], **kwargs)
|
||||
|
||||
|
||||
def resnet18_2p1d(**kwargs):
|
||||
return ResNet2p1d(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
|
||||
|
||||
def resnet26_2p1d(**kwargs):
|
||||
return ResNet2p1d(Bottleneck, [2, 2, 2, 2], **kwargs)
|
||||
|
||||
|
||||
def resnet34_2p1d(**kwargs):
|
||||
return ResNet2p1d(BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet50_2p1d(**kwargs):
|
||||
return ResNet2p1d(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet101_2p1d(**kwargs):
|
||||
return ResNet2p1d(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet152_2p1d(**kwargs):
|
||||
return ResNet2p1d(Bottleneck, [3, 8, 36, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet200_2p1d(**kwargs):
|
||||
return ResNet2p1d(Bottleneck, [3, 24, 36, 3], **kwargs)
|
||||
284
modelscope/models/cv/cmdssl_video_embedding/resnet3d.py
Normal file
284
modelscope/models/cv/cmdssl_video_embedding/resnet3d.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def conv3x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
return nn.Conv3d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv3d(
|
||||
in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
downsample=None,
|
||||
groups=1,
|
||||
base_width=64,
|
||||
dilation=1,
|
||||
norm_layer=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm3d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError(
|
||||
'BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError(
|
||||
'Dilation > 1 not supported in BasicBlock')
|
||||
self.conv1 = conv3x3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
downsample=None,
|
||||
groups=1,
|
||||
base_width=64,
|
||||
dilation=1,
|
||||
norm_layer=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm3d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
self.conv1 = conv1x1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
block,
|
||||
layers,
|
||||
num_classes=1000,
|
||||
zero_init_residual=True,
|
||||
groups=1,
|
||||
width_per_group=64,
|
||||
replace_stride_with_dilation=None,
|
||||
dropout=0.5,
|
||||
inplanes=3,
|
||||
first_stride=2,
|
||||
norm_layer=None,
|
||||
last_pool=True):
|
||||
super(ResNet3d, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm3d
|
||||
if not last_pool and num_classes is not None:
|
||||
raise ValueError('num_classes should be None when last_pool=False')
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError('replace_stride_with_dilation should be None '
|
||||
'or a 3-element tuple, got {}'.format(
|
||||
replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv3d(
|
||||
inplanes,
|
||||
self.inplanes,
|
||||
kernel_size=(3, 7, 7),
|
||||
stride=(1, first_stride, first_stride),
|
||||
padding=(1, 3, 3),
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool3d(
|
||||
kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(
|
||||
block,
|
||||
128,
|
||||
layers[1],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(
|
||||
block,
|
||||
256,
|
||||
layers[2],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(
|
||||
block,
|
||||
512,
|
||||
layers[3],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None
|
||||
if num_classes is None:
|
||||
self.dropout = None
|
||||
self.fc = None
|
||||
else:
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
self.out_planes = 512 * block.expansion
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv3d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion))
|
||||
|
||||
layers = []
|
||||
layers.append(
|
||||
block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(
|
||||
block(
|
||||
self.inplanes,
|
||||
planes,
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
if self.avgpool:
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
if self.dropout and self.fc:
|
||||
x = self.dropout(x)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def resnet10_3d(**kwargs):
|
||||
return ResNet3d(BasicBlock, [1, 1, 1, 1], **kwargs)
|
||||
|
||||
|
||||
def resnet18_3d(**kwargs):
|
||||
return ResNet3d(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
|
||||
|
||||
def resnet26_3d(**kwargs):
|
||||
return ResNet3d(Bottleneck, [2, 2, 2, 2], **kwargs)
|
||||
|
||||
|
||||
def resnet34_3d(**kwargs):
|
||||
return ResNet3d(BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet50_3d(**kwargs):
|
||||
return ResNet3d(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet101_3d(**kwargs):
|
||||
return ResNet3d(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet152_3d(**kwargs):
|
||||
return ResNet3d(Bottleneck, [3, 8, 36, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet200_3d(**kwargs):
|
||||
return ResNet3d(Bottleneck, [3, 24, 36, 3], **kwargs)
|
||||
@@ -59,6 +59,11 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.visual_question_answering:
|
||||
(Pipelines.visual_question_answering,
|
||||
'damo/mplug_visual-question-answering_coco_large_en'),
|
||||
Tasks.video_embedding: (Pipelines.cmdssl_video_embedding,
|
||||
'damo/cv_r2p1d_video_embedding'),
|
||||
Tasks.text_to_image_synthesis:
|
||||
(Pipelines.text_to_image_synthesis,
|
||||
'damo/cv_imagen_text-to-image-synthesis_tiny')
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
try:
|
||||
from .action_recognition_pipeline import ActionRecognitionPipeline
|
||||
from .animal_recog_pipeline import AnimalRecogPipeline
|
||||
from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'torch'":
|
||||
pass
|
||||
|
||||
@@ -2,9 +2,6 @@ import math
|
||||
import os.path as osp
|
||||
from typing import Any, Dict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
@@ -32,7 +29,9 @@ class ActionRecognitionPipeline(Pipeline):
|
||||
config_path = osp.join(self.model, ModelFile.CONFIGURATION)
|
||||
logger.info(f'loading config from {config_path}')
|
||||
self.cfg = Config.from_file(config_path)
|
||||
self.infer_model = BaseVideoModel(cfg=self.cfg).cuda()
|
||||
self.device = torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device)
|
||||
self.infer_model.eval()
|
||||
self.infer_model.load_state_dict(torch.load(model_path)['model_state'])
|
||||
self.label_mapping = self.cfg.label_mapping
|
||||
@@ -40,7 +39,7 @@ class ActionRecognitionPipeline(Pipeline):
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
if isinstance(input, str):
|
||||
video_input_data = ReadVideoData(self.cfg, input).cuda()
|
||||
video_input_data = ReadVideoData(self.cfg, input).to(self.device)
|
||||
else:
|
||||
raise TypeError(f'input should be a str,'
|
||||
f' but got {type(input)}')
|
||||
|
||||
157
modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py
Normal file
157
modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import math
|
||||
import os.path as osp
|
||||
from typing import Any, Dict
|
||||
|
||||
import cv2
|
||||
import decord
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.cmdssl_video_embedding.resnet2p1d import \
|
||||
resnet26_2p1d
|
||||
from modelscope.pipelines.base import Input
|
||||
from modelscope.pipelines.outputs import OutputKeys
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from ..base import Pipeline
|
||||
from ..builder import PIPELINES
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.video_embedding, module_name=Pipelines.cmdssl_video_embedding)
|
||||
class CMDSSLVideoEmbeddingPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str):
|
||||
super().__init__(model=model)
|
||||
model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
|
||||
logger.info(f'loading model from {model_path}')
|
||||
config_path = osp.join(self.model, ModelFile.CONFIGURATION)
|
||||
logger.info(f'loading config from {config_path}')
|
||||
self.cfg = Config.from_file(config_path)
|
||||
self.model = resnet26_2p1d(num_classes=None, last_pool=True)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self._device = torch.device('cuda')
|
||||
else:
|
||||
self._device = torch.device('cpu')
|
||||
self.model = self.model.to(self._device).eval().requires_grad_(False)
|
||||
self.model.load_state_dict(torch.load(model_path))
|
||||
logger.info('load model done')
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
decord.bridge.set_bridge('native')
|
||||
|
||||
transforms = VCompose([
|
||||
VRescale(size=self.cfg.DATA.scale_size),
|
||||
VCenterCrop(size=self.cfg.DATA.crop_size),
|
||||
VToTensor(),
|
||||
VNormalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std)
|
||||
])
|
||||
|
||||
clip_len = (self.cfg.DATA.video_frames
|
||||
- 1) * self.cfg.DATA.video_stride + 1
|
||||
vr = VideoReader(input, ctx=cpu(0))
|
||||
if len(vr) <= clip_len:
|
||||
init_frames = np.zeros(self.cfg.DATA.multi_crop, dtype=int)
|
||||
else:
|
||||
init_frames = np.linspace(0,
|
||||
len(vr) - clip_len,
|
||||
self.cfg.DATA.multi_crop + 1)
|
||||
init_frames = ((init_frames[1:] + init_frames[:-1])
|
||||
/ 2.).astype(int)
|
||||
|
||||
indices = np.arange(0, clip_len, self.cfg.DATA.video_stride)
|
||||
indices = (init_frames[:, None] + indices[None, :]).reshape(-1)
|
||||
indices[indices >= len(vr)] = 0
|
||||
|
||||
frames = torch.from_numpy(vr.get_batch(indices).asnumpy()).chunk(
|
||||
self.cfg.DATA.multi_crop, dim=0)
|
||||
frames = [
|
||||
transforms([Image.fromarray(f) for f in u.numpy()]) for u in frames
|
||||
]
|
||||
frames = torch.stack(frames, dim=0)
|
||||
result = {'video_data': frames}
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
frames = input['video_data'].to(self._device)
|
||||
feature = self.model(frames)
|
||||
feature = feature.mean(0)
|
||||
return {OutputKeys.VIDEO_EMBEDDING: feature.data.cpu().numpy()}
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return inputs
|
||||
|
||||
|
||||
class VCompose(object):
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, item):
|
||||
for t in self.transforms:
|
||||
item = t(item)
|
||||
return item
|
||||
|
||||
|
||||
class VRescale(object):
|
||||
|
||||
def __init__(self, size=128):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, vclip):
|
||||
w, h = vclip[0].size
|
||||
scale = self.size / min(w, h)
|
||||
out_w, out_h = int(round(w * scale)), int(round(h * scale))
|
||||
vclip = [u.resize((out_w, out_h), Image.BILINEAR) for u in vclip]
|
||||
return vclip
|
||||
|
||||
|
||||
class VCenterCrop(object):
|
||||
|
||||
def __init__(self, size=112):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, vclip):
|
||||
w, h = vclip[0].size
|
||||
assert min(w, h) >= self.size
|
||||
x1 = (w - self.size) // 2
|
||||
y1 = (h - self.size) // 2
|
||||
vclip = [
|
||||
u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in vclip
|
||||
]
|
||||
return vclip
|
||||
|
||||
|
||||
class VToTensor(object):
|
||||
|
||||
def __call__(self, vclip):
|
||||
vclip = torch.stack([TF.to_tensor(u) for u in vclip], dim=1)
|
||||
return vclip
|
||||
|
||||
|
||||
class VNormalize(object):
|
||||
|
||||
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, vclip):
|
||||
assert vclip.min() > -0.1 and vclip.max() < 1.1, \
|
||||
'vclip values should be in [0, 1]'
|
||||
vclip = vclip.clone()
|
||||
if not isinstance(self.mean, torch.Tensor):
|
||||
self.mean = vclip.new_tensor(self.mean).view(-1, 1, 1, 1)
|
||||
if not isinstance(self.std, torch.Tensor):
|
||||
self.std = vclip.new_tensor(self.std).view(-1, 1, 1, 1)
|
||||
vclip.sub_(self.mean).div_(self.std)
|
||||
return vclip
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.pipelines.base import Input
|
||||
@@ -23,7 +23,7 @@ class TextToImageSynthesisPipeline(Pipeline):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'execpting a Model instance or str, but get {type(model)}.')
|
||||
f'expecting a Model instance or str, but get {type(model)}.')
|
||||
|
||||
super().__init__(model=pipe_model)
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ class OutputKeys(object):
|
||||
RESPONSE = 'response'
|
||||
PREDICTION = 'prediction'
|
||||
DIALOG_STATES = 'dialog_states'
|
||||
VIDEO_EMBEDDING = 'video_embedding'
|
||||
|
||||
|
||||
TASK_OUTPUTS = {
|
||||
@@ -91,6 +92,12 @@ TASK_OUTPUTS = {
|
||||
# }
|
||||
Tasks.ocr_detection: [OutputKeys.POLYGONS],
|
||||
|
||||
# video embedding result for single video
|
||||
# {
|
||||
# "video_embedding": np.array with shape [D],
|
||||
# }
|
||||
Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING],
|
||||
|
||||
# ============ nlp tasks ===================
|
||||
|
||||
# text classification result for single sample
|
||||
|
||||
@@ -31,6 +31,7 @@ class Tasks(object):
|
||||
image_matting = 'image-matting'
|
||||
ocr_detection = 'ocr-detection'
|
||||
action_recognition = 'action-recognition'
|
||||
video_embedding = 'video-embedding'
|
||||
|
||||
# nlp tasks
|
||||
word_segmentation = 'word-segmentation'
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# !/usr/bin/env python
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
|
||||
from modelscope.fileio import File
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
@@ -45,7 +41,7 @@ class ActionRecognitionTest(unittest.TestCase):
|
||||
|
||||
print(f'recognition output: {result}.')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_modelhub_default_model(self):
|
||||
recognition_pipeline = pipeline(Tasks.action_recognition)
|
||||
result = recognition_pipeline(
|
||||
|
||||
30
tests/pipelines/test_cmdssl_video_embedding.py
Normal file
30
tests/pipelines/test_cmdssl_video_embedding.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# !/usr/bin/env python
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
|
||||
from modelscope.fileio import File
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class CMDSSLVideoEmbeddingTest(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_modelhub(self):
|
||||
videossl_pipeline = pipeline(
|
||||
Tasks.video_embedding, model='damo/cv_r2p1d_video_embedding')
|
||||
result = videossl_pipeline(
|
||||
'data/test/videos/action_recognition_test_video.mp4')
|
||||
|
||||
print(f'video embedding output: {result}.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user