mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
Merge branch 'master-gitlab' into merge_master_github_1128
This commit is contained in:
3
data/test/images/image_depth_estimation.jpg
Normal file
3
data/test/images/image_depth_estimation.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3b230497f6ca10be42aed92b86db435d74fd7306746a059b4ad1e0d6b0652806
|
||||
size 35694
|
||||
@@ -36,6 +36,7 @@ class Models(object):
|
||||
swinL_semantic_segmentation = 'swinL-semantic-segmentation'
|
||||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
|
||||
text_driven_segmentation = 'text-driven-segmentation'
|
||||
newcrfs_depth_estimation = 'newcrfs-depth-estimation'
|
||||
resnet50_bert = 'resnet50-bert'
|
||||
referring_video_object_segmentation = 'swinT-referring-video-object-segmentation'
|
||||
fer = 'fer'
|
||||
@@ -210,6 +211,7 @@ class Pipelines(object):
|
||||
video_summarization = 'googlenet_pgl_video_summarization'
|
||||
language_guided_video_summarization = 'clip-it-video-summarization'
|
||||
image_semantic_segmentation = 'image-semantic-segmentation'
|
||||
image_depth_estimation = 'image-depth-estimation'
|
||||
image_reid_person = 'passvitb-image-reid-person'
|
||||
image_inpainting = 'fft-inpainting'
|
||||
text_driven_segmentation = 'text-driven-segmentation'
|
||||
|
||||
@@ -5,10 +5,10 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models.builder import MODELS, build_model
|
||||
from modelscope.models.builder import build_model
|
||||
from modelscope.utils.checkpoint import save_checkpoint, save_pretrained
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile, Tasks
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
|
||||
from modelscope.utils.device import verify_device
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -94,6 +94,10 @@ class Model(ABC):
|
||||
if prefetched is not None:
|
||||
kwargs.pop('model_prefetched')
|
||||
|
||||
invoked_by = kwargs.get(Invoke.KEY)
|
||||
if invoked_by is not None:
|
||||
kwargs.pop(Invoke.KEY)
|
||||
|
||||
if osp.exists(model_name_or_path):
|
||||
local_model_dir = model_name_or_path
|
||||
else:
|
||||
@@ -101,7 +105,13 @@ class Model(ABC):
|
||||
raise RuntimeError(
|
||||
'Expecting model is pre-fetched locally, but is not found.'
|
||||
)
|
||||
local_model_dir = snapshot_download(model_name_or_path, revision)
|
||||
|
||||
if invoked_by is not None:
|
||||
invoked_by = '%s/%s' % (Invoke.KEY, invoked_by)
|
||||
else:
|
||||
invoked_by = '%s/%s' % (Invoke.KEY, Invoke.PRETRAINED)
|
||||
local_model_dir = snapshot_download(
|
||||
model_name_or_path, revision, user_agent=invoked_by)
|
||||
logger.info(f'initialize model from {local_model_dir}')
|
||||
if cfg_dict is not None:
|
||||
cfg = cfg_dict
|
||||
@@ -133,6 +143,7 @@ class Model(ABC):
|
||||
model.cfg = cfg
|
||||
|
||||
model.name = model_name_or_path
|
||||
model.model_dir = local_model_dir
|
||||
return model
|
||||
|
||||
def save_pretrained(self,
|
||||
|
||||
@@ -224,8 +224,8 @@ class BodyKeypointsDetection3D(TorchModel):
|
||||
lst_pose2d_cannoical.append(pose2d_canonical[:,
|
||||
i - pad:i + pad + 1])
|
||||
|
||||
input_pose2d_rr = torch.concat(lst_pose2d_cannoical, axis=0)
|
||||
input_pose2d_cannoical = torch.concat(lst_pose2d_cannoical, axis=0)
|
||||
input_pose2d_rr = torch.cat(lst_pose2d_cannoical, axis=0)
|
||||
input_pose2d_cannoical = torch.cat(lst_pose2d_cannoical, axis=0)
|
||||
|
||||
if self.cfg.model.MODEL.USE_CANONICAL_COORDS:
|
||||
input_pose2d_abs = input_pose2d_cannoical.clone()
|
||||
|
||||
1
modelscope/models/cv/image_depth_estimation/__init__.py
Normal file
1
modelscope/models/cv/image_depth_estimation/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
@@ -0,0 +1,215 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .newcrf_layers import NewCRF
|
||||
from .swin_transformer import SwinTransformer
|
||||
from .uper_crf_head import PSP
|
||||
|
||||
|
||||
class NewCRFDepth(nn.Module):
|
||||
"""
|
||||
Depth network based on neural window FC-CRFs architecture.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
version=None,
|
||||
inv_depth=False,
|
||||
pretrained=None,
|
||||
frozen_stages=-1,
|
||||
min_depth=0.1,
|
||||
max_depth=100.0,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.inv_depth = inv_depth
|
||||
self.with_auxiliary_head = False
|
||||
self.with_neck = False
|
||||
|
||||
norm_cfg = dict(type='BN', requires_grad=True)
|
||||
# norm_cfg = dict(type='GN', requires_grad=True, num_groups=8)
|
||||
|
||||
window_size = int(version[-2:])
|
||||
|
||||
if version[:-2] == 'base':
|
||||
embed_dim = 128
|
||||
depths = [2, 2, 18, 2]
|
||||
num_heads = [4, 8, 16, 32]
|
||||
in_channels = [128, 256, 512, 1024]
|
||||
elif version[:-2] == 'large':
|
||||
embed_dim = 192
|
||||
depths = [2, 2, 18, 2]
|
||||
num_heads = [6, 12, 24, 48]
|
||||
in_channels = [192, 384, 768, 1536]
|
||||
elif version[:-2] == 'tiny':
|
||||
embed_dim = 96
|
||||
depths = [2, 2, 6, 2]
|
||||
num_heads = [3, 6, 12, 24]
|
||||
in_channels = [96, 192, 384, 768]
|
||||
|
||||
backbone_cfg = dict(
|
||||
embed_dim=embed_dim,
|
||||
depths=depths,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
ape=False,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
use_checkpoint=False,
|
||||
frozen_stages=frozen_stages)
|
||||
|
||||
embed_dim = 512
|
||||
decoder_cfg = dict(
|
||||
in_channels=in_channels,
|
||||
in_index=[0, 1, 2, 3],
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
channels=embed_dim,
|
||||
dropout_ratio=0.0,
|
||||
num_classes=32,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False)
|
||||
|
||||
self.backbone = SwinTransformer(**backbone_cfg)
|
||||
# v_dim = decoder_cfg['num_classes'] * 4
|
||||
win = 7
|
||||
crf_dims = [128, 256, 512, 1024]
|
||||
v_dims = [64, 128, 256, embed_dim]
|
||||
self.crf3 = NewCRF(
|
||||
input_dim=in_channels[3],
|
||||
embed_dim=crf_dims[3],
|
||||
window_size=win,
|
||||
v_dim=v_dims[3],
|
||||
num_heads=32)
|
||||
self.crf2 = NewCRF(
|
||||
input_dim=in_channels[2],
|
||||
embed_dim=crf_dims[2],
|
||||
window_size=win,
|
||||
v_dim=v_dims[2],
|
||||
num_heads=16)
|
||||
self.crf1 = NewCRF(
|
||||
input_dim=in_channels[1],
|
||||
embed_dim=crf_dims[1],
|
||||
window_size=win,
|
||||
v_dim=v_dims[1],
|
||||
num_heads=8)
|
||||
self.crf0 = NewCRF(
|
||||
input_dim=in_channels[0],
|
||||
embed_dim=crf_dims[0],
|
||||
window_size=win,
|
||||
v_dim=v_dims[0],
|
||||
num_heads=4)
|
||||
|
||||
self.decoder = PSP(**decoder_cfg)
|
||||
self.disp_head1 = DispHead(input_dim=crf_dims[0])
|
||||
|
||||
self.up_mode = 'bilinear'
|
||||
if self.up_mode == 'mask':
|
||||
self.mask_head = nn.Sequential(
|
||||
nn.Conv2d(crf_dims[0], 64, 3, padding=1),
|
||||
nn.ReLU(inplace=True), nn.Conv2d(64, 16 * 9, 1, padding=0))
|
||||
|
||||
self.min_depth = min_depth
|
||||
self.max_depth = max_depth
|
||||
|
||||
self.init_weights(pretrained=pretrained)
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
"""Initialize the weights in backbone and heads.
|
||||
|
||||
Args:
|
||||
pretrained (str, optional): Path to pre-trained weights.
|
||||
Defaults to None.
|
||||
"""
|
||||
# print(f'== Load encoder backbone from: {pretrained}')
|
||||
self.backbone.init_weights(pretrained=pretrained)
|
||||
self.decoder.init_weights()
|
||||
if self.with_auxiliary_head:
|
||||
if isinstance(self.auxiliary_head, nn.ModuleList):
|
||||
for aux_head in self.auxiliary_head:
|
||||
aux_head.init_weights()
|
||||
else:
|
||||
self.auxiliary_head.init_weights()
|
||||
|
||||
def upsample_mask(self, disp, mask):
|
||||
""" Upsample disp [H/4, W/4, 1] -> [H, W, 1] using convex combination """
|
||||
N, _, H, W = disp.shape
|
||||
mask = mask.view(N, 1, 9, 4, 4, H, W)
|
||||
mask = torch.softmax(mask, dim=2)
|
||||
|
||||
up_disp = F.unfold(disp, kernel_size=3, padding=1)
|
||||
up_disp = up_disp.view(N, 1, 9, 1, 1, H, W)
|
||||
|
||||
up_disp = torch.sum(mask * up_disp, dim=2)
|
||||
up_disp = up_disp.permute(0, 1, 4, 2, 5, 3)
|
||||
return up_disp.reshape(N, 1, 4 * H, 4 * W)
|
||||
|
||||
def forward(self, imgs):
|
||||
|
||||
feats = self.backbone(imgs)
|
||||
if self.with_neck:
|
||||
feats = self.neck(feats)
|
||||
|
||||
ppm_out = self.decoder(feats)
|
||||
|
||||
e3 = self.crf3(feats[3], ppm_out)
|
||||
e3 = nn.PixelShuffle(2)(e3)
|
||||
e2 = self.crf2(feats[2], e3)
|
||||
e2 = nn.PixelShuffle(2)(e2)
|
||||
e1 = self.crf1(feats[1], e2)
|
||||
e1 = nn.PixelShuffle(2)(e1)
|
||||
e0 = self.crf0(feats[0], e1)
|
||||
|
||||
if self.up_mode == 'mask':
|
||||
mask = self.mask_head(e0)
|
||||
d1 = self.disp_head1(e0, 1)
|
||||
d1 = self.upsample_mask(d1, mask)
|
||||
else:
|
||||
d1 = self.disp_head1(e0, 4)
|
||||
|
||||
depth = d1 * self.max_depth
|
||||
|
||||
return depth
|
||||
|
||||
|
||||
class DispHead(nn.Module):
|
||||
|
||||
def __init__(self, input_dim=100):
|
||||
super(DispHead, self).__init__()
|
||||
# self.norm1 = nn.BatchNorm2d(input_dim)
|
||||
self.conv1 = nn.Conv2d(input_dim, 1, 3, padding=1)
|
||||
# self.relu = nn.ReLU(inplace=True)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x, scale):
|
||||
# x = self.relu(self.norm1(x))
|
||||
x = self.sigmoid(self.conv1(x))
|
||||
if scale > 1:
|
||||
x = upsample(x, scale_factor=scale)
|
||||
return x
|
||||
|
||||
|
||||
class DispUnpack(nn.Module):
|
||||
|
||||
def __init__(self, input_dim=100, hidden_dim=128):
|
||||
super(DispUnpack, self).__init__()
|
||||
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
||||
self.conv2 = nn.Conv2d(hidden_dim, 16, 3, padding=1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.pixel_shuffle = nn.PixelShuffle(4)
|
||||
|
||||
def forward(self, x, output_size):
|
||||
x = self.relu(self.conv1(x))
|
||||
x = self.sigmoid(self.conv2(x)) # [b, 16, h/4, w/4]
|
||||
# x = torch.reshape(x, [x.shape[0], 1, x.shape[2]*4, x.shape[3]*4])
|
||||
x = self.pixel_shuffle(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def upsample(x, scale_factor=2, mode='bilinear', align_corners=False):
|
||||
"""Upsample input tensor by a factor of 2
|
||||
"""
|
||||
return F.interpolate(
|
||||
x, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
|
||||
@@ -0,0 +1,504 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
""" Multilayer perceptron."""
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
window_size (int): window size
|
||||
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
||||
C)
|
||||
windows = x.permute(0, 1, 3, 2, 4,
|
||||
5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports both of shifted and non-shifted window.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
window_size,
|
||||
num_heads,
|
||||
v_dim,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
||||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :,
|
||||
None] - coords_flatten[:,
|
||||
None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :,
|
||||
0] += self.window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
|
||||
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(v_dim, v_dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, v, mask=None):
|
||||
""" Forward function.
|
||||
|
||||
Args:
|
||||
x: input features with shape of (num_windows*B, N, C)
|
||||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qk = self.qk(x).reshape(B_, N, 2, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k = qk[0], qk[
|
||||
1] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1],
|
||||
self.window_size[0] * self.window_size[1],
|
||||
-1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
||||
N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
# assert self.dim % v.shape[-1] == 0, "self.dim % v.shape[-1] != 0"
|
||||
# repeat_num = self.dim // v.shape[-1]
|
||||
# v = v.view(B_, N, self.num_heads // repeat_num, -1).transpose(1, 2).repeat(1, repeat_num, 1, 1)
|
||||
|
||||
assert self.dim == v.shape[-1], 'self.dim != v.shape[-1]'
|
||||
v = v.view(B_, N, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class CRFBlock(nn.Module):
|
||||
""" CRF Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
v_dim,
|
||||
window_size=7,
|
||||
shift_size=0,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.v_dim = v_dim
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = WindowAttention(
|
||||
dim,
|
||||
window_size=to_2tuple(self.window_size),
|
||||
num_heads=num_heads,
|
||||
v_dim=v_dim,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop)
|
||||
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(v_dim)
|
||||
mlp_hidden_dim = int(v_dim * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=v_dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
|
||||
self.H = None
|
||||
self.W = None
|
||||
|
||||
def forward(self, x, v, mask_matrix):
|
||||
""" Forward function.
|
||||
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
mask_matrix: Attention mask for cyclic shift.
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
H, W = self.H, self.W
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# pad feature maps to multiples of window size
|
||||
pad_l = pad_t = 0
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
_, Hp, Wp, _ = x.shape
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_x = torch.roll(
|
||||
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
shifted_v = torch.roll(
|
||||
v, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
attn_mask = mask_matrix
|
||||
else:
|
||||
shifted_x = x
|
||||
shifted_v = v
|
||||
attn_mask = None
|
||||
|
||||
# partition windows
|
||||
x_windows = window_partition(
|
||||
shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
||||
x_windows = x_windows.view(-1, self.window_size * self.window_size,
|
||||
C) # nW*B, window_size*window_size, C
|
||||
v_windows = window_partition(
|
||||
shifted_v, self.window_size) # nW*B, window_size, window_size, C
|
||||
v_windows = v_windows.view(
|
||||
-1, self.window_size * self.window_size,
|
||||
v_windows.shape[-1]) # nW*B, window_size*window_size, C
|
||||
|
||||
# W-MSA/SW-MSA
|
||||
attn_windows = self.attn(
|
||||
x_windows, v_windows,
|
||||
mask=attn_mask) # nW*B, window_size*window_size, C
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size,
|
||||
self.window_size, self.v_dim)
|
||||
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
|
||||
Wp) # B H' W' C
|
||||
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(
|
||||
shifted_x,
|
||||
shifts=(self.shift_size, self.shift_size),
|
||||
dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, self.v_dim)
|
||||
|
||||
# FFN
|
||||
x = shortcut + self.drop_path(x)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicCRFLayer(nn.Module):
|
||||
""" A basic NeWCRFs layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of feature channels
|
||||
depth (int): Depths of this stage.
|
||||
num_heads (int): Number of attention head.
|
||||
window_size (int): Local window size. Default: 7.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
depth,
|
||||
num_heads,
|
||||
v_dim,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
downsample=None,
|
||||
use_checkpoint=False):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.shift_size = window_size // 2
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
CRFBlock(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
v_dim=v_dim,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i]
|
||||
if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer) for i in range(depth)
|
||||
])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, v, H, W):
|
||||
""" Forward function.
|
||||
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
"""
|
||||
|
||||
# calculate attention mask for SW-MSA
|
||||
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
||||
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
||||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(
|
||||
img_mask, self.window_size) # nW, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1,
|
||||
self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
||||
float(-100.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0))
|
||||
|
||||
for blk in self.blocks:
|
||||
blk.H, blk.W = H, W
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x, attn_mask)
|
||||
else:
|
||||
x = blk(x, v, attn_mask)
|
||||
if self.downsample is not None:
|
||||
x_down = self.downsample(x, H, W)
|
||||
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
||||
return x, H, W, x_down, Wh, Ww
|
||||
else:
|
||||
return x, H, W, x, H, W
|
||||
|
||||
|
||||
class NewCRF(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
input_dim=96,
|
||||
embed_dim=96,
|
||||
v_dim=64,
|
||||
window_size=7,
|
||||
num_heads=4,
|
||||
depth=2,
|
||||
patch_size=4,
|
||||
in_chans=3,
|
||||
norm_layer=nn.LayerNorm,
|
||||
patch_norm=True):
|
||||
super().__init__()
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.patch_norm = patch_norm
|
||||
|
||||
if input_dim != embed_dim:
|
||||
self.proj_x = nn.Conv2d(input_dim, embed_dim, 3, padding=1)
|
||||
else:
|
||||
self.proj_x = None
|
||||
|
||||
if v_dim != embed_dim:
|
||||
self.proj_v = nn.Conv2d(v_dim, embed_dim, 3, padding=1)
|
||||
elif embed_dim % v_dim == 0:
|
||||
self.proj_v = None
|
||||
|
||||
v_dim = embed_dim
|
||||
assert v_dim == embed_dim
|
||||
|
||||
self.crf_layer = BasicCRFLayer(
|
||||
dim=embed_dim,
|
||||
depth=depth,
|
||||
num_heads=num_heads,
|
||||
v_dim=v_dim,
|
||||
window_size=window_size,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=norm_layer,
|
||||
downsample=None,
|
||||
use_checkpoint=False)
|
||||
|
||||
layer = norm_layer(embed_dim)
|
||||
layer_name = 'norm_crf'
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
def forward(self, x, v):
|
||||
if self.proj_x is not None:
|
||||
x = self.proj_x(x)
|
||||
if self.proj_v is not None:
|
||||
v = self.proj_v(v)
|
||||
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
v = v.transpose(1, 2).transpose(2, 3)
|
||||
|
||||
x_out, H, W, x, Wh, Ww = self.crf_layer(x, v, Wh, Ww)
|
||||
norm_layer = getattr(self, 'norm_crf')
|
||||
x_out = norm_layer(x_out)
|
||||
out = x_out.view(-1, H, W, self.embed_dim).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
|
||||
return out
|
||||
@@ -0,0 +1,272 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import os.path as osp
|
||||
import pkgutil
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from importlib import import_module
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torch import distributed as dist
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
from torch.utils import model_zoo
|
||||
|
||||
TORCH_VERSION = torch.__version__
|
||||
|
||||
|
||||
def resize(input,
|
||||
size=None,
|
||||
scale_factor=None,
|
||||
mode='nearest',
|
||||
align_corners=None,
|
||||
warning=True):
|
||||
if warning:
|
||||
if size is not None and align_corners:
|
||||
input_h, input_w = tuple(int(x) for x in input.shape[2:])
|
||||
output_h, output_w = tuple(int(x) for x in size)
|
||||
if output_h > input_h or output_w > output_h:
|
||||
if ((output_h > 1 and output_w > 1 and input_h > 1
|
||||
and input_w > 1) and (output_h - 1) % (input_h - 1)
|
||||
and (output_w - 1) % (input_w - 1)):
|
||||
warnings.warn(
|
||||
f'When align_corners={align_corners}, '
|
||||
'the output would more aligned if '
|
||||
f'input size {(input_h, input_w)} is `x+1` and '
|
||||
f'out size {(output_h, output_w)} is `nx+1`')
|
||||
if isinstance(size, torch.Size):
|
||||
size = tuple(int(x) for x in size)
|
||||
return F.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
|
||||
|
||||
def normal_init(module, mean=0, std=1, bias=0):
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
nn.init.normal_(module.weight, mean, std)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def is_module_wrapper(module):
|
||||
module_wrappers = (DataParallel, DistributedDataParallel)
|
||||
return isinstance(module, module_wrappers)
|
||||
|
||||
|
||||
def get_dist_info():
|
||||
if TORCH_VERSION < '1.0':
|
||||
initialized = dist._initialized
|
||||
else:
|
||||
if dist.is_available():
|
||||
initialized = dist.is_initialized()
|
||||
else:
|
||||
initialized = False
|
||||
if initialized:
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
rank = 0
|
||||
world_size = 1
|
||||
return rank, world_size
|
||||
|
||||
|
||||
def load_state_dict(module, state_dict, strict=False, logger=None):
|
||||
"""Load state_dict to a module.
|
||||
|
||||
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
|
||||
Default value for ``strict`` is set to ``False`` and the message for
|
||||
param mismatch will be shown even if strict is False.
|
||||
|
||||
Args:
|
||||
module (Module): Module that receives the state_dict.
|
||||
state_dict (OrderedDict): Weights.
|
||||
strict (bool): whether to strictly enforce that the keys
|
||||
in :attr:`state_dict` match the keys returned by this module's
|
||||
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
|
||||
logger (:obj:`logging.Logger`, optional): Logger to log the error
|
||||
message. If not specified, print function will be used.
|
||||
"""
|
||||
unexpected_keys = []
|
||||
all_missing_keys = []
|
||||
err_msg = []
|
||||
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
# use _load_from_state_dict to enable checkpoint version control
|
||||
def load(module, prefix=''):
|
||||
# recursively check parallel module in case that the model has a
|
||||
# complicated structure, e.g., nn.Module(nn.Module(DDP))
|
||||
if is_module_wrapper(module):
|
||||
module = module.module
|
||||
local_metadata = {} if metadata is None else metadata.get(
|
||||
prefix[:-1], {})
|
||||
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
|
||||
all_missing_keys, unexpected_keys,
|
||||
err_msg)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + '.')
|
||||
|
||||
load(module)
|
||||
load = None # break load->load reference cycle
|
||||
|
||||
# ignore "num_batches_tracked" of BN layers
|
||||
missing_keys = [
|
||||
key for key in all_missing_keys if 'num_batches_tracked' not in key
|
||||
]
|
||||
|
||||
if unexpected_keys:
|
||||
err_msg.append('unexpected key in source '
|
||||
f'state_dict: {", ".join(unexpected_keys)}\n')
|
||||
if missing_keys:
|
||||
err_msg.append(
|
||||
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
|
||||
|
||||
rank, _ = get_dist_info()
|
||||
if len(err_msg) > 0 and rank == 0:
|
||||
err_msg.insert(
|
||||
0, 'The model and loaded state dict do not match exactly\n')
|
||||
err_msg = '\n'.join(err_msg)
|
||||
if strict:
|
||||
raise RuntimeError(err_msg)
|
||||
elif logger is not None:
|
||||
logger.warning(err_msg)
|
||||
else:
|
||||
print(err_msg)
|
||||
|
||||
|
||||
def load_url_dist(url, model_dir=None):
|
||||
"""In distributed setting, this function only download checkpoint at local
|
||||
rank 0."""
|
||||
rank, world_size = get_dist_info()
|
||||
rank = int(os.environ.get('LOCAL_RANK', rank))
|
||||
if rank == 0:
|
||||
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
if rank > 0:
|
||||
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
|
||||
return checkpoint
|
||||
|
||||
|
||||
def get_torchvision_models():
|
||||
model_urls = dict()
|
||||
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
|
||||
if ispkg:
|
||||
continue
|
||||
_zoo = import_module(f'torchvision.models.{name}')
|
||||
if hasattr(_zoo, 'model_urls'):
|
||||
_urls = getattr(_zoo, 'model_urls')
|
||||
model_urls.update(_urls)
|
||||
return model_urls
|
||||
|
||||
|
||||
def _load_checkpoint(filename, map_location=None):
|
||||
"""Load checkpoint from somewhere (modelzoo, file, url).
|
||||
|
||||
Args:
|
||||
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
||||
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
||||
details.
|
||||
map_location (str | None): Same as :func:`torch.load`. Default: None.
|
||||
|
||||
Returns:
|
||||
dict | OrderedDict: The loaded checkpoint. It can be either an
|
||||
OrderedDict storing model weights or a dict containing other
|
||||
information, which depends on the checkpoint.
|
||||
"""
|
||||
if filename.startswith('modelzoo://'):
|
||||
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
|
||||
'use "torchvision://" instead')
|
||||
model_urls = get_torchvision_models()
|
||||
model_name = filename[11:]
|
||||
checkpoint = load_url_dist(model_urls[model_name])
|
||||
else:
|
||||
if not osp.isfile(filename):
|
||||
raise IOError(f'{filename} is not a checkpoint file')
|
||||
checkpoint = torch.load(filename, map_location=map_location)
|
||||
return checkpoint
|
||||
|
||||
|
||||
def load_checkpoint(model,
|
||||
filename,
|
||||
map_location='cpu',
|
||||
strict=False,
|
||||
logger=None):
|
||||
"""Load checkpoint from a file or URI.
|
||||
|
||||
Args:
|
||||
model (Module): Module to load checkpoint.
|
||||
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
||||
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
||||
details.
|
||||
map_location (str): Same as :func:`torch.load`.
|
||||
strict (bool): Whether to allow different params for the model and
|
||||
checkpoint.
|
||||
logger (:mod:`logging.Logger` or None): The logger for error message.
|
||||
|
||||
Returns:
|
||||
dict or OrderedDict: The loaded checkpoint.
|
||||
"""
|
||||
checkpoint = _load_checkpoint(filename, map_location)
|
||||
# OrderedDict is a subclass of dict
|
||||
if not isinstance(checkpoint, dict):
|
||||
raise RuntimeError(
|
||||
f'No state_dict found in checkpoint file {filename}')
|
||||
# get state_dict from checkpoint
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
elif 'model' in checkpoint:
|
||||
state_dict = checkpoint['model']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
# strip prefix of state_dict
|
||||
if list(state_dict.keys())[0].startswith('module.'):
|
||||
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||
|
||||
# for MoBY, load model of online branch
|
||||
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
|
||||
state_dict = {
|
||||
k.replace('encoder.', ''): v
|
||||
for k, v in state_dict.items() if k.startswith('encoder.')
|
||||
}
|
||||
|
||||
# reshape absolute position embedding
|
||||
if state_dict.get('absolute_pos_embed') is not None:
|
||||
absolute_pos_embed = state_dict['absolute_pos_embed']
|
||||
N1, L, C1 = absolute_pos_embed.size()
|
||||
N2, C2, H, W = model.absolute_pos_embed.size()
|
||||
if N1 != N2 or C1 != C2 or L != H * W:
|
||||
logger.warning('Error in loading absolute_pos_embed, pass')
|
||||
else:
|
||||
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
|
||||
N2, H, W, C2).permute(0, 3, 1, 2)
|
||||
|
||||
# interpolate position bias table if needed
|
||||
relative_position_bias_table_keys = [
|
||||
k for k in state_dict.keys() if 'relative_position_bias_table' in k
|
||||
]
|
||||
for table_key in relative_position_bias_table_keys:
|
||||
table_pretrained = state_dict[table_key]
|
||||
table_current = model.state_dict()[table_key]
|
||||
L1, nH1 = table_pretrained.size()
|
||||
L2, nH2 = table_current.size()
|
||||
if nH1 != nH2:
|
||||
logger.warning(f'Error in loading {table_key}, pass')
|
||||
else:
|
||||
if L1 != L2:
|
||||
S1 = int(L1**0.5)
|
||||
S2 = int(L2**0.5)
|
||||
table_pretrained_resized = F.interpolate(
|
||||
table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
|
||||
size=(S2, S2),
|
||||
mode='bicubic')
|
||||
state_dict[table_key] = table_pretrained_resized.view(
|
||||
nH2, L2).permute(1, 0)
|
||||
|
||||
# load state_dict
|
||||
load_state_dict(model, state_dict, strict, logger)
|
||||
return checkpoint
|
||||
@@ -0,0 +1,706 @@
|
||||
# The implementation is adopted from Swin Transformer
|
||||
# made publicly available under the MIT License at https://github.com/microsoft/Swin-Transformer
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
from .newcrf_utils import load_checkpoint
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
""" Multilayer perceptron."""
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
window_size (int): window size
|
||||
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
||||
C)
|
||||
windows = x.permute(0, 1, 3, 2, 4,
|
||||
5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports both of shifted and non-shifted window.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
window_size,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
||||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :,
|
||||
None] - coords_flatten[:,
|
||||
None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :,
|
||||
0] += self.window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
""" Forward function.
|
||||
|
||||
Args:
|
||||
x: input features with shape of (num_windows*B, N, C)
|
||||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[
|
||||
2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1],
|
||||
self.window_size[0] * self.window_size[1],
|
||||
-1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
||||
N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
""" Swin Transformer Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
shift_size=0,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = WindowAttention(
|
||||
dim,
|
||||
window_size=to_2tuple(self.window_size),
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop)
|
||||
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
|
||||
self.H = None
|
||||
self.W = None
|
||||
|
||||
def forward(self, x, mask_matrix):
|
||||
""" Forward function.
|
||||
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
mask_matrix: Attention mask for cyclic shift.
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
H, W = self.H, self.W
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# pad feature maps to multiples of window size
|
||||
pad_l = pad_t = 0
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
_, Hp, Wp, _ = x.shape
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_x = torch.roll(
|
||||
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
attn_mask = mask_matrix
|
||||
else:
|
||||
shifted_x = x
|
||||
attn_mask = None
|
||||
|
||||
# partition windows
|
||||
x_windows = window_partition(
|
||||
shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
||||
x_windows = x_windows.view(-1, self.window_size * self.window_size,
|
||||
C) # nW*B, window_size*window_size, C
|
||||
|
||||
# W-MSA/SW-MSA
|
||||
attn_windows = self.attn(
|
||||
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size,
|
||||
self.window_size, C)
|
||||
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
|
||||
Wp) # B H' W' C
|
||||
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(
|
||||
shifted_x,
|
||||
shifts=(self.shift_size, self.shift_size),
|
||||
dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
# FFN
|
||||
x = shortcut + self.drop_path(x)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
""" Patch Merging Layer
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||
self.norm = norm_layer(4 * dim)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
""" Forward function.
|
||||
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# padding
|
||||
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
||||
if pad_input:
|
||||
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
""" A basic Swin Transformer layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of feature channels
|
||||
depth (int): Depths of this stage.
|
||||
num_heads (int): Number of attention head.
|
||||
window_size (int): Local window size. Default: 7.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
downsample=None,
|
||||
use_checkpoint=False):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.shift_size = window_size // 2
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i]
|
||||
if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer) for i in range(depth)
|
||||
])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, H, W):
|
||||
""" Forward function.
|
||||
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
"""
|
||||
|
||||
# calculate attention mask for SW-MSA
|
||||
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
||||
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
||||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(
|
||||
img_mask, self.window_size) # nW, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1,
|
||||
self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
||||
float(-100.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0))
|
||||
|
||||
for blk in self.blocks:
|
||||
blk.H, blk.W = H, W
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x, attn_mask)
|
||||
else:
|
||||
x = blk(x, attn_mask)
|
||||
if self.downsample is not None:
|
||||
x_down = self.downsample(x, H, W)
|
||||
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
||||
return x, H, W, x_down, Wh, Ww
|
||||
else:
|
||||
return x, H, W, x, H, W
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
|
||||
Args:
|
||||
patch_size (int): Patch token size. Default: 4.
|
||||
in_chans (int): Number of input image channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
patch_size=4,
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
norm_layer=None):
|
||||
super().__init__()
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
# padding
|
||||
_, _, H, W = x.size()
|
||||
if W % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
||||
if H % self.patch_size[0] != 0:
|
||||
x = F.pad(x,
|
||||
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||
|
||||
x = self.proj(x) # B C Wh Ww
|
||||
if self.norm is not None:
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformer(nn.Module):
|
||||
""" Swin Transformer backbone.
|
||||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||
https://arxiv.org/pdf/2103.14030
|
||||
|
||||
Args:
|
||||
pretrain_img_size (int): Input image size for training the pretrained model,
|
||||
used in absolute postion embedding. Default 224.
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4.
|
||||
in_chans (int): Number of input image channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
depths (tuple[int]): Depths of each Swin Transformer stage.
|
||||
num_heads (tuple[int]): Number of attention head of each stage.
|
||||
window_size (int): Window size. Default: 7.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop_rate (float): Dropout rate.
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters.
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrain_img_size=224,
|
||||
patch_size=4,
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.2,
|
||||
norm_layer=nn.LayerNorm,
|
||||
ape=False,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
use_checkpoint=False):
|
||||
super().__init__()
|
||||
|
||||
self.pretrain_img_size = pretrain_img_size
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.ape = ape
|
||||
self.patch_norm = patch_norm
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
|
||||
# absolute position embedding
|
||||
if self.ape:
|
||||
pretrain_img_size = to_2tuple(pretrain_img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
patches_resolution = [
|
||||
pretrain_img_size[0] // patch_size[0],
|
||||
pretrain_img_size[1] // patch_size[1]
|
||||
]
|
||||
|
||||
self.absolute_pos_embed = nn.Parameter(
|
||||
torch.zeros(1, embed_dim, patches_resolution[0],
|
||||
patches_resolution[1]))
|
||||
trunc_normal_(self.absolute_pos_embed, std=.02)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# stochastic depth
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = BasicLayer(
|
||||
dim=int(embed_dim * 2**i_layer),
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if
|
||||
(i_layer < self.num_layers - 1) else None,
|
||||
use_checkpoint=use_checkpoint)
|
||||
self.layers.append(layer)
|
||||
|
||||
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
||||
self.num_features = num_features
|
||||
|
||||
# add a norm layer for each output
|
||||
for i_layer in out_indices:
|
||||
layer = norm_layer(num_features[i_layer])
|
||||
layer_name = f'norm{i_layer}'
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.frozen_stages >= 1 and self.ape:
|
||||
self.absolute_pos_embed.requires_grad = False
|
||||
|
||||
if self.frozen_stages >= 2:
|
||||
self.pos_drop.eval()
|
||||
for i in range(0, self.frozen_stages - 1):
|
||||
m = self.layers[i]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
"""Initialize the weights in backbone.
|
||||
|
||||
Args:
|
||||
pretrained (str, optional): Path to pre-trained weights.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
if isinstance(pretrained, str):
|
||||
self.apply(_init_weights)
|
||||
# logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False)
|
||||
elif pretrained is None:
|
||||
self.apply(_init_weights)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
x = self.patch_embed(x)
|
||||
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
if self.ape:
|
||||
# interpolate the position embedding to the corresponding size
|
||||
absolute_pos_embed = F.interpolate(
|
||||
self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
||||
x = (x + absolute_pos_embed).flatten(2).transpose(1,
|
||||
2) # B Wh*Ww C
|
||||
else:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
outs = []
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
x_out = norm_layer(x_out)
|
||||
|
||||
out = x_out.view(-1, H, W,
|
||||
self.num_features[i]).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode while keep layers freezed."""
|
||||
super(SwinTransformer, self).train(mode)
|
||||
self._freeze_stages()
|
||||
@@ -0,0 +1,365 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from .newcrf_utils import normal_init, resize
|
||||
|
||||
|
||||
class PPM(nn.ModuleList):
|
||||
"""Pooling Pyramid Module used in PSPNet.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
|
||||
act_cfg, align_corners):
|
||||
super(PPM, self).__init__()
|
||||
self.pool_scales = pool_scales
|
||||
self.align_corners = align_corners
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
for pool_scale in pool_scales:
|
||||
# == if batch size = 1, BN is not supported, change to GN
|
||||
if pool_scale == 1:
|
||||
norm_cfg = dict(type='GN', requires_grad=True, num_groups=256)
|
||||
self.append(
|
||||
nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(pool_scale),
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=self.act_cfg)))
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
ppm_outs = []
|
||||
for ppm in self:
|
||||
ppm_out = ppm(x)
|
||||
upsampled_ppm_out = resize(
|
||||
ppm_out,
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
ppm_outs.append(upsampled_ppm_out)
|
||||
return ppm_outs
|
||||
|
||||
|
||||
class BaseDecodeHead(nn.Module):
|
||||
"""Base class for BaseDecodeHead.
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
num_classes (int): Number of classes.
|
||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU')
|
||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
Default: None.
|
||||
loss_decode (dict): Config of decode loss.
|
||||
Default: dict(type='CrossEntropyLoss').
|
||||
ignore_index (int | None): The label index to be ignored. When using
|
||||
masked BCE loss, ignore_index should be set to None. Default: 255
|
||||
sampler (dict|None): The config of segmentation map sampler.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
*,
|
||||
num_classes,
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
in_index=-1,
|
||||
input_transform=None,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=1.0),
|
||||
ignore_index=255,
|
||||
sampler=None,
|
||||
align_corners=False):
|
||||
super(BaseDecodeHead, self).__init__()
|
||||
self._init_inputs(in_channels, in_index, input_transform)
|
||||
self.channels = channels
|
||||
self.num_classes = num_classes
|
||||
self.dropout_ratio = dropout_ratio
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.in_index = in_index
|
||||
# self.loss_decode = build_loss(loss_decode)
|
||||
self.ignore_index = ignore_index
|
||||
self.align_corners = align_corners
|
||||
# if sampler is not None:
|
||||
# self.sampler = build_pixel_sampler(sampler, context=self)
|
||||
# else:
|
||||
# self.sampler = None
|
||||
|
||||
# self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
||||
# self.conv1 = nn.Conv2d(channels, num_classes, 3, padding=1)
|
||||
if dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
||||
else:
|
||||
self.dropout = None
|
||||
self.fp16_enabled = False
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f'input_transform={self.input_transform}, ' \
|
||||
f'ignore_index={self.ignore_index}, ' \
|
||||
f'align_corners={self.align_corners}'
|
||||
return s
|
||||
|
||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
||||
"""Check and initialize input transforms.
|
||||
|
||||
The in_channels, in_index and input_transform must match.
|
||||
Specifically, when input_transform is None, only single feature map
|
||||
will be selected. So in_channels and in_index must be of type int.
|
||||
When input_transform
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
in_index (int|Sequence[int]): Input feature index.
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
"""
|
||||
|
||||
if input_transform is not None:
|
||||
assert input_transform in ['resize_concat', 'multiple_select']
|
||||
self.input_transform = input_transform
|
||||
self.in_index = in_index
|
||||
if input_transform is not None:
|
||||
assert isinstance(in_channels, (list, tuple))
|
||||
assert isinstance(in_index, (list, tuple))
|
||||
assert len(in_channels) == len(in_index)
|
||||
if input_transform == 'resize_concat':
|
||||
self.in_channels = sum(in_channels)
|
||||
else:
|
||||
self.in_channels = in_channels
|
||||
else:
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(in_index, int)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights of classification layer."""
|
||||
# normal_init(self.conv_seg, mean=0, std=0.01)
|
||||
# normal_init(self.conv1, mean=0, std=0.01)
|
||||
|
||||
def _transform_inputs(self, inputs):
|
||||
"""Transform inputs for decoder.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
Tensor: The transformed inputs
|
||||
"""
|
||||
|
||||
if self.input_transform == 'resize_concat':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
upsampled_inputs = [
|
||||
resize(
|
||||
input=x,
|
||||
size=inputs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners) for x in inputs
|
||||
]
|
||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
||||
elif self.input_transform == 'multiple_select':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
else:
|
||||
inputs = inputs[self.in_index]
|
||||
|
||||
return inputs
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Placeholder of forward function."""
|
||||
pass
|
||||
|
||||
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
|
||||
"""Forward function for training.
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
img_metas (list[dict]): List of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
gt_semantic_seg (Tensor): Semantic segmentation masks
|
||||
used if the architecture supports semantic segmentation task.
|
||||
train_cfg (dict): The training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
seg_logits = self.forward(inputs)
|
||||
losses = self.losses(seg_logits, gt_semantic_seg)
|
||||
return losses
|
||||
|
||||
def forward_test(self, inputs, img_metas, test_cfg):
|
||||
"""Forward function for testing.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
img_metas (list[dict]): List of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Output segmentation map.
|
||||
"""
|
||||
return self.forward(inputs)
|
||||
|
||||
|
||||
class UPerHead(BaseDecodeHead):
|
||||
|
||||
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
|
||||
super(UPerHead, self).__init__(
|
||||
input_transform='multiple_select', **kwargs)
|
||||
# FPN Module
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.fpn_convs = nn.ModuleList()
|
||||
for in_channels in self.in_channels: # skip the top layer
|
||||
l_conv = ConvModule(
|
||||
in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
inplace=True)
|
||||
fpn_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
inplace=True)
|
||||
self.lateral_convs.append(l_conv)
|
||||
self.fpn_convs.append(fpn_conv)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
|
||||
inputs = self._transform_inputs(inputs)
|
||||
|
||||
# build laterals
|
||||
laterals = [
|
||||
lateral_conv(inputs[i])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
|
||||
# laterals.append(self.psp_forward(inputs))
|
||||
|
||||
# build top-down path
|
||||
used_backbone_levels = len(laterals)
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
prev_shape = laterals[i - 1].shape[2:]
|
||||
laterals[i - 1] += resize(
|
||||
laterals[i],
|
||||
size=prev_shape,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
# build outputs
|
||||
fpn_outs = [
|
||||
self.fpn_convs[i](laterals[i])
|
||||
for i in range(used_backbone_levels - 1)
|
||||
]
|
||||
# append psp feature
|
||||
fpn_outs.append(laterals[-1])
|
||||
|
||||
return fpn_outs[0]
|
||||
|
||||
|
||||
class PSP(BaseDecodeHead):
|
||||
"""Unified Perceptual Parsing for Scene Understanding.
|
||||
|
||||
This head is the implementation of `UPerNet
|
||||
<https://arxiv.org/abs/1807.10221>`_.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module applied on the last feature. Default: (1, 2, 3, 6).
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
|
||||
super(PSP, self).__init__(input_transform='multiple_select', **kwargs)
|
||||
# PSP Module
|
||||
self.psp_modules = PPM(
|
||||
pool_scales,
|
||||
self.in_channels[-1],
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels[-1] + len(pool_scales) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def psp_forward(self, inputs):
|
||||
"""Forward function of PSP module."""
|
||||
x = inputs[-1]
|
||||
psp_outs = [x]
|
||||
psp_outs.extend(self.psp_modules(x))
|
||||
psp_outs = torch.cat(psp_outs, dim=1)
|
||||
output = self.bottleneck(psp_outs)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
inputs = self._transform_inputs(inputs)
|
||||
|
||||
return self.psp_forward(inputs)
|
||||
53
modelscope/models/cv/image_depth_estimation/newcrfs_model.py
Normal file
53
modelscope/models/cv/image_depth_estimation/newcrfs_model.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base.base_torch_model import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.cv.image_depth_estimation.networks.newcrf_depth import \
|
||||
NewCRFDepth
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.image_depth_estimation, module_name=Models.newcrfs_depth_estimation)
|
||||
class DepthEstimation(TorchModel):
|
||||
|
||||
def __init__(self, model_dir: str, **kwargs):
|
||||
"""str -- model file root."""
|
||||
super().__init__(model_dir, **kwargs)
|
||||
|
||||
# build model
|
||||
self.model = NewCRFDepth(
|
||||
version='large07', inv_depth=False, max_depth=10)
|
||||
|
||||
# load model
|
||||
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
|
||||
checkpoint = torch.load(model_path)
|
||||
|
||||
state_dict = {}
|
||||
for k in checkpoint['model'].keys():
|
||||
if k.startswith('module.'):
|
||||
state_dict[k[7:]] = checkpoint['model'][k]
|
||||
else:
|
||||
state_dict[k] = checkpoint['model'][k]
|
||||
self.model.load_state_dict(state_dict)
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, Inputs):
|
||||
return self.model(Inputs['imgs'])
|
||||
|
||||
def postprocess(self, Inputs):
|
||||
depth_result = Inputs
|
||||
|
||||
results = {OutputKeys.DEPTHS: depth_result}
|
||||
return results
|
||||
|
||||
def inference(self, data):
|
||||
results = self.forward(data)
|
||||
|
||||
return results
|
||||
@@ -509,8 +509,8 @@ def convert_weights(model: nn.Module):
|
||||
@MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip)
|
||||
class CLIPForMultiModalEmbedding(TorchModel):
|
||||
|
||||
def __init__(self, model_dir, device_id=-1):
|
||||
super().__init__(model_dir=model_dir, device_id=device_id)
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
super().__init__(model_dir=model_dir, *args, **kwargs)
|
||||
|
||||
# Initialize the model.
|
||||
vision_model_config_file = '{}/vision_model_config.json'.format(
|
||||
|
||||
@@ -9,7 +9,6 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Tensor, TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
|
||||
@@ -730,7 +730,7 @@ def make_msa_feat_v2(batch):
|
||||
batch['cluster_profile'],
|
||||
deletion_mean_value,
|
||||
]
|
||||
batch['msa_feat'] = torch.concat(msa_feat, dim=-1)
|
||||
batch['msa_feat'] = torch.cat(msa_feat, dim=-1)
|
||||
return batch
|
||||
|
||||
|
||||
@@ -1320,7 +1320,7 @@ def get_contiguous_crop_idx(
|
||||
asym_offset + this_start + csz))
|
||||
asym_offset += ll
|
||||
|
||||
return torch.concat(crop_idxs)
|
||||
return torch.cat(crop_idxs)
|
||||
|
||||
|
||||
def get_spatial_crop_idx(
|
||||
|
||||
@@ -217,7 +217,7 @@ class MSAAttention(nn.Module):
|
||||
if mask is not None else None)
|
||||
outputs.append(
|
||||
self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias))
|
||||
return torch.concat(outputs, dim=-3)
|
||||
return torch.cat(outputs, dim=-3)
|
||||
|
||||
def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None):
|
||||
m = self.layer_norm_m(m)
|
||||
|
||||
@@ -19,6 +19,7 @@ class OutputKeys(object):
|
||||
BOXES = 'boxes'
|
||||
KEYPOINTS = 'keypoints'
|
||||
MASKS = 'masks'
|
||||
DEPTHS = 'depths'
|
||||
TEXT = 'text'
|
||||
POLYGONS = 'polygons'
|
||||
OUTPUT = 'output'
|
||||
|
||||
@@ -16,7 +16,7 @@ from modelscope.outputs import TASK_OUTPUTS
|
||||
from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Frameworks, ModelFile
|
||||
from modelscope.utils.constant import Frameworks, Invoke, ModelFile
|
||||
from modelscope.utils.device import (create_device, device_placement,
|
||||
verify_device)
|
||||
from modelscope.utils.hub import read_config, snapshot_download
|
||||
@@ -47,8 +47,10 @@ class Pipeline(ABC):
|
||||
logger.info(f'initiate model from location {model}.')
|
||||
# expecting model has been prefetched to local cache beforehand
|
||||
return Model.from_pretrained(
|
||||
model, model_prefetched=True,
|
||||
device=self.device_name) if is_model(model) else model
|
||||
model,
|
||||
device=self.device_name,
|
||||
model_prefetched=True,
|
||||
invoked_by=Invoke.PIPELINE) if is_model(model) else model
|
||||
else:
|
||||
return model
|
||||
|
||||
@@ -231,7 +233,7 @@ class Pipeline(ABC):
|
||||
batch_data[k] = value_list
|
||||
for k in batch_data.keys():
|
||||
if isinstance(batch_data[k][0], torch.Tensor):
|
||||
batch_data[k] = torch.concat(batch_data[k])
|
||||
batch_data[k] = torch.cat(batch_data[k])
|
||||
return batch_data
|
||||
|
||||
def _process_batch(self, input: List[Input], batch_size,
|
||||
@@ -383,15 +385,12 @@ class DistributedPipeline(Pipeline):
|
||||
preprocessor: Union[Preprocessor, List[Preprocessor]] = None,
|
||||
auto_collate=True,
|
||||
**kwargs):
|
||||
self.preprocessor = preprocessor
|
||||
super().__init__(model=model, preprocessor=preprocessor, kwargs=kwargs)
|
||||
self._model_prepare = False
|
||||
self._model_prepare_lock = Lock()
|
||||
self._auto_collate = auto_collate
|
||||
|
||||
if os.path.exists(model):
|
||||
self.model_dir = model
|
||||
else:
|
||||
self.model_dir = snapshot_download(model)
|
||||
self.model_dir = self.model.model_dir
|
||||
self.cfg = read_config(self.model_dir)
|
||||
self.world_size = self.cfg.model.world_size
|
||||
self.model_pool = None
|
||||
|
||||
@@ -7,7 +7,7 @@ from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.utils.config import ConfigDict, check_config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Tasks
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, Tasks
|
||||
from modelscope.utils.hub import read_config
|
||||
from modelscope.utils.registry import Registry, build_from_cfg
|
||||
from .base import Pipeline
|
||||
@@ -147,6 +147,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.image_segmentation:
|
||||
(Pipelines.image_instance_segmentation,
|
||||
'damo/cv_swin-b_image-instance-segmentation_coco'),
|
||||
Tasks.image_depth_estimation:
|
||||
(Pipelines.image_depth_estimation,
|
||||
'damo/cv_newcrfs_image-depth-estimation_indoor'),
|
||||
Tasks.image_style_transfer: (Pipelines.image_style_transfer,
|
||||
'damo/cv_aams_style-transfer_damo'),
|
||||
Tasks.face_image_generation: (Pipelines.face_image_generation,
|
||||
@@ -209,6 +212,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.referring_video_object_segmentation:
|
||||
(Pipelines.referring_video_object_segmentation,
|
||||
'damo/cv_swin-t_referring_video-object-segmentation'),
|
||||
Tasks.video_summarization: (Pipelines.video_summarization,
|
||||
'damo/cv_googlenet_pgl-video-summarization'),
|
||||
}
|
||||
|
||||
|
||||
@@ -220,14 +225,19 @@ def normalize_model_input(model, model_revision):
|
||||
# skip revision download if model is a local directory
|
||||
if not os.path.exists(model):
|
||||
# note that if there is already a local copy, snapshot_download will check and skip downloading
|
||||
model = snapshot_download(model, revision=model_revision)
|
||||
model = snapshot_download(
|
||||
model,
|
||||
revision=model_revision,
|
||||
user_agent={Invoke.KEY: Invoke.PIPELINE})
|
||||
elif isinstance(model, list) and isinstance(model[0], str):
|
||||
for idx in range(len(model)):
|
||||
if is_official_hub_path(
|
||||
model[idx],
|
||||
model_revision) and not os.path.exists(model[idx]):
|
||||
model[idx] = snapshot_download(
|
||||
model[idx], revision=model_revision)
|
||||
model[idx],
|
||||
revision=model_revision,
|
||||
user_agent={Invoke.KEY: Invoke.PIPELINE})
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@@ -8,14 +8,13 @@ import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.animal_recognition import Bottleneck, ResNet
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.constant import Devices, ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@@ -67,15 +66,10 @@ class AnimalRecognitionPipeline(Pipeline):
|
||||
filter_param(src_params, own_state)
|
||||
model.load_state_dict(own_state)
|
||||
|
||||
self.model = resnest101(num_classes=8288)
|
||||
local_model_dir = model
|
||||
if osp.exists(model):
|
||||
local_model_dir = model
|
||||
else:
|
||||
local_model_dir = snapshot_download(model)
|
||||
self.local_path = local_model_dir
|
||||
self.local_path = self.model
|
||||
src_params = torch.load(
|
||||
osp.join(local_model_dir, 'pytorch_model.pt'), 'cpu')
|
||||
osp.join(self.local_path, ModelFile.TORCH_MODEL_FILE), Devices.cpu)
|
||||
self.model = resnest101(num_classes=8288)
|
||||
load_pretrained(self.model, src_params)
|
||||
logger.info('load model done')
|
||||
|
||||
|
||||
@@ -120,8 +120,7 @@ class Body3DKeypointsPipeline(Pipeline):
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
self.keypoint_model_3d = model if isinstance(
|
||||
model, BodyKeypointsDetection3D) else Model.from_pretrained(model)
|
||||
self.keypoint_model_3d = self.model
|
||||
self.keypoint_model_3d.eval()
|
||||
|
||||
# init human body 2D keypoints detection pipeline
|
||||
|
||||
@@ -11,7 +11,7 @@ from PIL import ImageFile
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.pipelines.util import is_official_hub_path
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
|
||||
from modelscope.utils.device import create_device
|
||||
|
||||
|
||||
@@ -37,7 +37,9 @@ class EasyCVPipeline(object):
|
||||
assert is_official_hub_path(
|
||||
model), 'Only support local model path and official hub path!'
|
||||
model_dir = snapshot_download(
|
||||
model_id=model, revision=DEFAULT_MODEL_REVISION)
|
||||
model_id=model,
|
||||
revision=DEFAULT_MODEL_REVISION,
|
||||
user_agent={Invoke.KEY: Invoke.PIPELINE})
|
||||
|
||||
assert osp.isdir(model_dir)
|
||||
model_files = glob.glob(
|
||||
@@ -48,6 +50,7 @@ class EasyCVPipeline(object):
|
||||
|
||||
model_path = model_files[0]
|
||||
self.model_path = model_path
|
||||
self.model_dir = model_dir
|
||||
|
||||
# get configuration file from source model dir
|
||||
self.config_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
|
||||
|
||||
@@ -24,7 +24,6 @@ class HumanWholebodyKeypointsPipeline(EasyCVPipeline):
|
||||
model (str): model id on modelscope hub or local model path.
|
||||
model_file_pattern (str): model file pattern.
|
||||
"""
|
||||
self.model_dir = model
|
||||
super(HumanWholebodyKeypointsPipeline, self).__init__(
|
||||
model=model,
|
||||
model_file_pattern=model_file_pattern,
|
||||
|
||||
@@ -8,7 +8,6 @@ import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.animal_recognition import resnet
|
||||
from modelscope.outputs import OutputKeys
|
||||
@@ -67,16 +66,12 @@ class GeneralRecognitionPipeline(Pipeline):
|
||||
filter_param(src_params, own_state)
|
||||
model.load_state_dict(own_state)
|
||||
|
||||
self.model = resnest101(num_classes=54092)
|
||||
local_model_dir = model
|
||||
device = 'cpu'
|
||||
if osp.exists(model):
|
||||
local_model_dir = model
|
||||
else:
|
||||
local_model_dir = snapshot_download(model)
|
||||
self.local_path = local_model_dir
|
||||
self.local_path = self.model
|
||||
src_params = torch.load(
|
||||
osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE), device)
|
||||
osp.join(self.local_path, ModelFile.TORCH_MODEL_FILE), device)
|
||||
|
||||
self.model = resnest101(num_classes=54092)
|
||||
load_pretrained(self.model, src_params)
|
||||
logger.info('load model done')
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ class Hand2DKeypointsPipeline(EasyCVPipeline):
|
||||
model (str): model id on modelscope hub or local model path.
|
||||
model_file_pattern (str): model file pattern.
|
||||
"""
|
||||
self.model_dir = model
|
||||
super(Hand2DKeypointsPipeline, self).__init__(
|
||||
model=model,
|
||||
model_file_pattern=model_file_pattern,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -25,22 +25,15 @@ class ImageClassificationPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
model: Union[Model, str],
|
||||
preprocessor: [Preprocessor] = None,
|
||||
preprocessor: Optional[Preprocessor] = None,
|
||||
**kwargs):
|
||||
super().__init__(model=model)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
assert isinstance(model, str) or isinstance(model, Model), \
|
||||
'model must be a single str or OfaForAllTasks'
|
||||
if isinstance(model, str):
|
||||
pipe_model = Model.from_pretrained(model)
|
||||
elif isinstance(model, Model):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError
|
||||
pipe_model.model.eval()
|
||||
pipe_model.to(get_device())
|
||||
if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
|
||||
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
|
||||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.eval()
|
||||
self.model.to(get_device())
|
||||
if preprocessor is None and isinstance(self.model, OfaForAllTasks):
|
||||
self.preprocessor = OfaPreprocessor(model_dir=self.model.model_dir)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return inputs
|
||||
|
||||
@@ -32,10 +32,8 @@ class ImageColorEnhancePipeline(Pipeline):
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
model = model if isinstance(
|
||||
model, ImageColorEnhance) else Model.from_pretrained(model)
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.eval()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self._device = torch.device('cuda')
|
||||
|
||||
@@ -32,17 +32,14 @@ class ImageDenoisePipeline(Pipeline):
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
model = model if isinstance(
|
||||
model, NAFNetForImageDenoise) else Model.from_pretrained(model)
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.config = model.config
|
||||
self.model.eval()
|
||||
self.config = self.model.config
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self._device = torch.device('cuda')
|
||||
else:
|
||||
self._device = torch.device('cpu')
|
||||
self.model = model
|
||||
logger.info('load image denoise model done')
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
|
||||
52
modelscope/pipelines/cv/image_depth_estimation_pipeline.py
Normal file
52
modelscope/pipelines/cv/image_depth_estimation_pipeline.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Model, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_depth_estimation, module_name=Pipelines.image_depth_estimation)
|
||||
class ImageDepthEstimationPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create a image depth estimation pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
logger.info('depth estimation model, pipeline init')
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
img = LoadImage.convert_to_ndarray(input).astype(np.float32)
|
||||
H, W = 480, 640
|
||||
img = cv2.resize(img, [W, H])
|
||||
img = img.transpose(2, 0, 1) / 255.0
|
||||
imgs = img[None, ...]
|
||||
data = {'imgs': imgs}
|
||||
|
||||
return data
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
results = self.model.inference(input)
|
||||
return results
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
results = self.model.postprocess(inputs)
|
||||
outputs = {OutputKeys.DEPTHS: results[OutputKeys.DEPTHS]}
|
||||
|
||||
return outputs
|
||||
@@ -44,7 +44,7 @@ class LanguageGuidedVideoSummarizationPipeline(Pipeline):
|
||||
"""
|
||||
super().__init__(model=model, auto_collate=False, **kwargs)
|
||||
logger.info(f'loading model from {model}')
|
||||
self.model_dir = model
|
||||
self.model_dir = self.model.model_dir
|
||||
|
||||
self.tmp_dir = kwargs.get('tmp_dir', None)
|
||||
if self.tmp_dir is None:
|
||||
|
||||
@@ -9,7 +9,6 @@ import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.virual_tryon import SDAFNet_Tryon
|
||||
from modelscope.outputs import OutputKeys
|
||||
@@ -52,17 +51,12 @@ class VirtualTryonPipeline(Pipeline):
|
||||
filter_param(src_params, own_state)
|
||||
model.load_state_dict(own_state)
|
||||
|
||||
self.model = SDAFNet_Tryon(ref_in_channel=6).to(self.device)
|
||||
local_model_dir = model
|
||||
if osp.exists(model):
|
||||
local_model_dir = model
|
||||
else:
|
||||
local_model_dir = snapshot_download(model)
|
||||
self.local_path = local_model_dir
|
||||
self.local_path = self.model
|
||||
src_params = torch.load(
|
||||
osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE), 'cpu')
|
||||
osp.join(self.local_path, ModelFile.TORCH_MODEL_FILE), 'cpu')
|
||||
self.model = SDAFNet_Tryon(ref_in_channel=6).to(self.device)
|
||||
load_pretrained(self.model, src_params)
|
||||
self.model = self.model.eval()
|
||||
self.model.eval()
|
||||
self.size = 192
|
||||
from torchvision import transforms
|
||||
self.test_transforms = transforms.Compose([
|
||||
|
||||
@@ -29,22 +29,13 @@ class ImageCaptioningPipeline(Pipeline):
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
assert isinstance(model, str) or isinstance(model, Model), \
|
||||
'model must be a single str or OfaForAllTasks'
|
||||
if isinstance(model, str):
|
||||
pipe_model = Model.from_pretrained(model)
|
||||
elif isinstance(model, Model):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError
|
||||
pipe_model.model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.eval()
|
||||
if preprocessor is None:
|
||||
if isinstance(pipe_model, OfaForAllTasks):
|
||||
preprocessor = OfaPreprocessor(pipe_model.model_dir)
|
||||
elif isinstance(pipe_model, MPlugForAllTasks):
|
||||
preprocessor = MPlugPreprocessor(pipe_model.model_dir)
|
||||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
|
||||
if isinstance(self.model, OfaForAllTasks):
|
||||
self.preprocessor = OfaPreprocessor(self.model.model_dir)
|
||||
elif isinstance(self.model, MPlugForAllTasks):
|
||||
self.preprocessor = MPlugPreprocessor(self.model.model_dir)
|
||||
|
||||
def _batch(self, data):
|
||||
if isinstance(self.model, OfaForAllTasks):
|
||||
@@ -55,17 +46,17 @@ class ImageCaptioningPipeline(Pipeline):
|
||||
batch_data['samples'] = [d['samples'][0] for d in data]
|
||||
batch_data['net_input'] = {}
|
||||
for k in data[0]['net_input'].keys():
|
||||
batch_data['net_input'][k] = torch.concat(
|
||||
batch_data['net_input'][k] = torch.cat(
|
||||
[d['net_input'][k] for d in data])
|
||||
|
||||
return batch_data
|
||||
elif isinstance(self.model, MPlugForAllTasks):
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
batch_data = dict(train=data[0]['train'])
|
||||
batch_data['image'] = torch.concat([d['image'] for d in data])
|
||||
batch_data['image'] = torch.cat([d['image'] for d in data])
|
||||
question = {}
|
||||
for k in data[0]['question'].keys():
|
||||
question[k] = torch.concat([d['question'][k] for d in data])
|
||||
question[k] = torch.cat([d['question'][k] for d in data])
|
||||
batch_data['question'] = BatchEncoding(question)
|
||||
return batch_data
|
||||
else:
|
||||
|
||||
@@ -28,19 +28,10 @@ class ImageTextRetrievalPipeline(Pipeline):
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
assert isinstance(model, str) or isinstance(model, Model), \
|
||||
f'model must be a single str or Model, but got {type(model)}'
|
||||
if isinstance(model, str):
|
||||
pipe_model = Model.from_pretrained(model)
|
||||
elif isinstance(model, Model):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError
|
||||
pipe_model.model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.eval()
|
||||
if preprocessor is None:
|
||||
preprocessor = MPlugPreprocessor(pipe_model.model_dir)
|
||||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
|
||||
self.preprocessor = MPlugPreprocessor(self.model.model_dir)
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
|
||||
@@ -28,21 +28,14 @@ class MultiModalEmbeddingPipeline(Pipeline):
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
pipe_model = Model.from_pretrained(model)
|
||||
elif isinstance(model, Model):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError('model must be a single str')
|
||||
pipe_model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.eval()
|
||||
if preprocessor is None:
|
||||
if isinstance(pipe_model, CLIPForMultiModalEmbedding):
|
||||
preprocessor = CLIPPreprocessor(pipe_model.model_dir)
|
||||
if isinstance(self.model, CLIPForMultiModalEmbedding):
|
||||
self.preprocessor = CLIPPreprocessor(self.model.model_dir)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return self.model(self.preprocess(input))
|
||||
|
||||
|
||||
@@ -28,20 +28,11 @@ class OcrRecognitionPipeline(Pipeline):
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
assert isinstance(model, str) or isinstance(model, Model), \
|
||||
'model must be a single str or OfaForAllTasks'
|
||||
if isinstance(model, str):
|
||||
pipe_model = Model.from_pretrained(model)
|
||||
elif isinstance(model, Model):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError
|
||||
pipe_model.model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.eval()
|
||||
if preprocessor is None:
|
||||
if isinstance(pipe_model, OfaForAllTasks):
|
||||
preprocessor = OfaPreprocessor(pipe_model.model_dir)
|
||||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
|
||||
if isinstance(self.model, OfaForAllTasks):
|
||||
self.preprocessor = OfaPreprocessor(self.model.model_dir)
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
|
||||
@@ -31,18 +31,10 @@ class TextToImageSynthesisPipeline(Pipeline):
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
device_id = 0 if torch.cuda.is_available() else -1
|
||||
if isinstance(model, str):
|
||||
pipe_model = Model.from_pretrained(model, device_id=device_id)
|
||||
elif isinstance(model, Model):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'expecting a Model instance or str, but get {type(model)}.')
|
||||
if preprocessor is None and isinstance(pipe_model,
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None and isinstance(self.model,
|
||||
OfaForTextToImageSynthesis):
|
||||
preprocessor = OfaPreprocessor(pipe_model.model_dir)
|
||||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
|
||||
self.preprocessor = OfaPreprocessor(self.model.model_dir)
|
||||
|
||||
def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
|
||||
if self.preprocessor is not None:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.multi_modal import OfaForAllTasks
|
||||
@@ -18,26 +18,17 @@ class VisualEntailmentPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
model: Union[Model, str],
|
||||
preprocessor: [Preprocessor] = None,
|
||||
preprocessor: Optional[Preprocessor] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
use `model` and `preprocessor` to create a visual entailment pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
assert isinstance(model, str) or isinstance(model, Model), \
|
||||
'model must be a single str or OfaForAllTasks'
|
||||
if isinstance(model, str):
|
||||
pipe_model = Model.from_pretrained(model)
|
||||
elif isinstance(model, Model):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError
|
||||
pipe_model.model.eval()
|
||||
if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
|
||||
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
|
||||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.eval()
|
||||
if preprocessor is None and isinstance(self.model, OfaForAllTasks):
|
||||
self.preprocessor = OfaPreprocessor(model_dir=self.model.model_dir)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return inputs
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.multi_modal import OfaForAllTasks
|
||||
@@ -18,26 +18,17 @@ class VisualGroundingPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
model: Union[Model, str],
|
||||
preprocessor: [Preprocessor] = None,
|
||||
preprocessor: Optional[Preprocessor] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
use `model` and `preprocessor` to create a visual grounding pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
assert isinstance(model, str) or isinstance(model, Model), \
|
||||
'model must be a single str or OfaForAllTasks'
|
||||
if isinstance(model, str):
|
||||
pipe_model = Model.from_pretrained(model)
|
||||
elif isinstance(model, Model):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError
|
||||
pipe_model.model.eval()
|
||||
if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
|
||||
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
|
||||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.model.eval()
|
||||
if preprocessor is None and isinstance(self.model, OfaForAllTasks):
|
||||
self.preprocessor = OfaPreprocessor(model_dir=self.model.model_dir)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return inputs
|
||||
|
||||
@@ -31,15 +31,13 @@ class VisualQuestionAnsweringPipeline(Pipeline):
|
||||
model (MPlugForVisualQuestionAnswering): a model instance
|
||||
preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance
|
||||
"""
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
if isinstance(model, OfaForAllTasks):
|
||||
preprocessor = OfaPreprocessor(model.model_dir)
|
||||
elif isinstance(model, MPlugForAllTasks):
|
||||
preprocessor = MPlugPreprocessor(model.model_dir)
|
||||
model.model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
if isinstance(self.model, OfaForAllTasks):
|
||||
self.preprocessor = OfaPreprocessor(self.model.model_dir)
|
||||
elif isinstance(self.model, MPlugForAllTasks):
|
||||
self.preprocessor = MPlugPreprocessor(self.model.model_dir)
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
|
||||
@@ -32,12 +32,10 @@ class ConversationalTextToSqlPipeline(Pipeline):
|
||||
preprocessor (ConversationalTextToSqlPreprocessor):
|
||||
a preprocessor instance
|
||||
"""
|
||||
model = model if isinstance(
|
||||
model, StarForTextToSql) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = ConversationalTextToSqlPreprocessor(model.model_dir)
|
||||
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
self.preprocessor = ConversationalTextToSqlPreprocessor(
|
||||
self.model.model_dir)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
@@ -30,13 +30,11 @@ class DialogIntentPredictionPipeline(Pipeline):
|
||||
or a SpaceForDialogIntent instance.
|
||||
preprocessor (DialogIntentPredictionPreprocessor): An optional preprocessor instance.
|
||||
"""
|
||||
model = model if isinstance(
|
||||
model, SpaceForDialogIntent) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = DialogIntentPredictionPreprocessor(model.model_dir)
|
||||
self.model = model
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.categories = preprocessor.categories
|
||||
if preprocessor is None:
|
||||
self.preprocessor = DialogIntentPredictionPreprocessor(
|
||||
self.model.model_dir)
|
||||
self.categories = self.preprocessor.categories
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
@@ -29,13 +29,10 @@ class DialogModelingPipeline(Pipeline):
|
||||
or a SpaceForDialogModeling instance.
|
||||
preprocessor (DialogModelingPreprocessor): An optional preprocessor instance.
|
||||
"""
|
||||
model = model if isinstance(
|
||||
model, SpaceForDialogModeling) else Model.from_pretrained(model)
|
||||
self.model = model
|
||||
if preprocessor is None:
|
||||
preprocessor = DialogModelingPreprocessor(model.model_dir)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.preprocessor = preprocessor
|
||||
if preprocessor is None:
|
||||
self.preprocessor = DialogModelingPreprocessor(
|
||||
self.model.model_dir)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
@@ -31,16 +31,13 @@ class DialogStateTrackingPipeline(Pipeline):
|
||||
from the model hub, or a SpaceForDialogStateTracking instance.
|
||||
preprocessor (DialogStateTrackingPreprocessor): An optional preprocessor instance.
|
||||
"""
|
||||
|
||||
model = model if isinstance(
|
||||
model, SpaceForDST) else Model.from_pretrained(model)
|
||||
self.model = model
|
||||
if preprocessor is None:
|
||||
preprocessor = DialogStateTrackingPreprocessor(model.model_dir)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
self.preprocessor = DialogStateTrackingPreprocessor(
|
||||
self.model.model_dir)
|
||||
|
||||
self.tokenizer = preprocessor.tokenizer
|
||||
self.config = preprocessor.config
|
||||
self.tokenizer = self.preprocessor.tokenizer
|
||||
self.config = self.preprocessor.config
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
@@ -31,27 +31,22 @@ class DocumentSegmentationPipeline(Pipeline):
|
||||
model: Union[Model, str],
|
||||
preprocessor: DocumentSegmentationPreprocessor = None,
|
||||
**kwargs):
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
|
||||
self.model_dir = model.model_dir
|
||||
self.model_cfg = model.forward()
|
||||
self.model_dir = self.model.model_dir
|
||||
self.model_cfg = self.model.forward()
|
||||
|
||||
if self.model_cfg['type'] == 'bert':
|
||||
config = BertConfig.from_pretrained(model.model_dir, num_labels=2)
|
||||
config = BertConfig.from_pretrained(self.model_dir, num_labels=2)
|
||||
elif self.model_cfg['type'] == 'ponet':
|
||||
config = PoNetConfig.from_pretrained(model.model_dir, num_labels=2)
|
||||
config = PoNetConfig.from_pretrained(self.model_dir, num_labels=2)
|
||||
|
||||
self.document_segmentation_model = model.build_with_config(
|
||||
self.document_segmentation_model = self.model.build_with_config(
|
||||
config=config)
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = DocumentSegmentationPreprocessor(
|
||||
self.model_dir, config)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
self.preprocessor = preprocessor
|
||||
self.preprocessor = DocumentSegmentationPreprocessor(
|
||||
self.model.model_dir, config)
|
||||
|
||||
def __call__(
|
||||
self, documents: Union[List[List[str]], List[str],
|
||||
|
||||
@@ -21,12 +21,10 @@ class FaqQuestionAnsweringPipeline(Pipeline):
|
||||
model: Union[str, Model],
|
||||
preprocessor: Preprocessor = None,
|
||||
**kwargs):
|
||||
model = Model.from_pretrained(model) if isinstance(model,
|
||||
str) else model
|
||||
if preprocessor is None:
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
model.model_dir, **kwargs)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
self.preprocessor = Preprocessor.from_pretrained(
|
||||
self.model.model_dir, **kwargs)
|
||||
|
||||
def _sanitize_parameters(self, **pipeline_parameters):
|
||||
return pipeline_parameters, pipeline_parameters, pipeline_parameters
|
||||
@@ -37,11 +35,11 @@ class FaqQuestionAnsweringPipeline(Pipeline):
|
||||
sentence_vecs = sentence_vecs.detach().tolist()
|
||||
return sentence_vecs
|
||||
|
||||
def forward(self, inputs: [list, Dict[str, Any]],
|
||||
def forward(self, inputs: Union[list, Dict[str, Any]],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
return self.model(inputs)
|
||||
|
||||
def postprocess(self, inputs: [list, Dict[str, Any]],
|
||||
def postprocess(self, inputs: Union[list, Dict[str, Any]],
|
||||
**postprocess_params) -> Dict[str, Any]:
|
||||
scores = inputs['scores']
|
||||
labels = []
|
||||
|
||||
@@ -46,21 +46,18 @@ class FeatureExtractionPipeline(Pipeline):
|
||||
|
||||
|
||||
"""
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = NLPPreprocessor(
|
||||
model.model_dir,
|
||||
padding=kwargs.pop('padding', False),
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
self.preprocessor = preprocessor
|
||||
if preprocessor is None:
|
||||
self.preprocessor = NLPPreprocessor(
|
||||
self.model.model_dir,
|
||||
padding=kwargs.pop('padding', False),
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
self.model.eval()
|
||||
|
||||
self.config = Config.from_file(
|
||||
os.path.join(model.model_dir, ModelFile.CONFIGURATION))
|
||||
self.tokenizer = preprocessor.tokenizer
|
||||
os.path.join(self.model.model_dir, ModelFile.CONFIGURATION))
|
||||
self.tokenizer = self.preprocessor.tokenizer
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
|
||||
@@ -53,22 +53,18 @@ class FillMaskPipeline(Pipeline):
|
||||
If the xlm-roberta(xlm-roberta, veco, etc.) based model is used, the mask token is '<mask>'.
|
||||
To view other examples plese check the tests/pipelines/test_fill_mask.py.
|
||||
"""
|
||||
|
||||
fill_mask_model = Model.from_pretrained(model) if isinstance(
|
||||
model, str) else model
|
||||
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
fill_mask_model.model_dir,
|
||||
self.preprocessor = Preprocessor.from_pretrained(
|
||||
self.model.model_dir,
|
||||
first_sequence=first_sequence,
|
||||
second_sequence=None,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
fill_mask_model.eval()
|
||||
assert hasattr(
|
||||
preprocessor, 'mask_id'
|
||||
), 'The input preprocessor should have the mask_id attribute.'
|
||||
super().__init__(
|
||||
model=fill_mask_model, preprocessor=preprocessor, **kwargs)
|
||||
assert hasattr(
|
||||
self.preprocessor, 'mask_id'
|
||||
), 'The input preprocessor should have the mask_id attribute.'
|
||||
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
|
||||
@@ -25,15 +25,12 @@ class InformationExtractionPipeline(Pipeline):
|
||||
model: Union[Model, str],
|
||||
preprocessor: Optional[Preprocessor] = None,
|
||||
**kwargs):
|
||||
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = RelationExtractionPreprocessor(
|
||||
model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 512))
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
self.preprocessor = RelationExtractionPreprocessor(
|
||||
self.model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 512))
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
|
||||
@@ -21,7 +21,7 @@ class MGLMTextSummarizationPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
model: Union[MGLMForTextSummarization, str],
|
||||
preprocessor: [Preprocessor] = None,
|
||||
preprocessor: Optional[Preprocessor] = None,
|
||||
*args,
|
||||
**kwargs):
|
||||
model = MGLMForTextSummarization(model) if isinstance(model,
|
||||
|
||||
@@ -50,15 +50,12 @@ class NamedEntityRecognitionPipeline(TokenClassificationPipeline):
|
||||
|
||||
To view other examples plese check the tests/pipelines/test_named_entity_recognition.py.
|
||||
"""
|
||||
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = TokenClassificationPreprocessor(
|
||||
model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
self.preprocessor = TokenClassificationPreprocessor(
|
||||
self.model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
self.model.eval()
|
||||
self.id2label = kwargs.get('id2label')
|
||||
if self.id2label is None and hasattr(self.preprocessor, 'id2label'):
|
||||
self.id2label = self.preprocessor.id2label
|
||||
@@ -73,13 +70,11 @@ class NamedEntityRecognitionThaiPipeline(NamedEntityRecognitionPipeline):
|
||||
model: Union[Model, str],
|
||||
preprocessor: Optional[Preprocessor] = None,
|
||||
**kwargs):
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = NERPreprocessorThai(
|
||||
model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 512))
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
self.preprocessor = NERPreprocessorThai(
|
||||
self.model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 512))
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
@@ -91,10 +86,8 @@ class NamedEntityRecognitionVietPipeline(NamedEntityRecognitionPipeline):
|
||||
model: Union[Model, str],
|
||||
preprocessor: Optional[Preprocessor] = None,
|
||||
**kwargs):
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = NERPreprocessorViet(
|
||||
model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 512))
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
self.preprocessor = NERPreprocessorViet(
|
||||
self.model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 512))
|
||||
|
||||
@@ -32,14 +32,13 @@ class SentenceEmbeddingPipeline(Pipeline):
|
||||
the model if supplied.
|
||||
sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.
|
||||
"""
|
||||
model = Model.from_pretrained(model) if isinstance(model,
|
||||
str) else model
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
model.model_dir if isinstance(model, Model) else model,
|
||||
self.preprocessor = Preprocessor.from_pretrained(
|
||||
self.model.model_dir
|
||||
if isinstance(self.model, Model) else model,
|
||||
first_sequence=first_sequence,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.multi_modal import OfaForAllTasks
|
||||
@@ -18,7 +18,7 @@ class SummarizationPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
model: Union[Model, str],
|
||||
preprocessor: [Preprocessor] = None,
|
||||
preprocessor: Optional[Preprocessor] = None,
|
||||
**kwargs):
|
||||
"""Use `model` and `preprocessor` to create a Summarization pipeline for prediction.
|
||||
|
||||
@@ -27,19 +27,10 @@ class SummarizationPipeline(Pipeline):
|
||||
or a model id from the model hub, or a model instance.
|
||||
preprocessor (Preprocessor): An optional preprocessor instance.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
assert isinstance(model, str) or isinstance(model, Model), \
|
||||
'model must be a single str or OfaForAllTasks'
|
||||
if isinstance(model, str):
|
||||
pipe_model = Model.from_pretrained(model)
|
||||
elif isinstance(model, Model):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError
|
||||
pipe_model.model.eval()
|
||||
if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
|
||||
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
|
||||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.eval()
|
||||
if preprocessor is None and isinstance(self.model, OfaForAllTasks):
|
||||
self.preprocessor = OfaPreprocessor(model_dir=self.model.model_dir)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return inputs
|
||||
|
||||
@@ -41,21 +41,22 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
||||
preprocessor (TableQuestionAnsweringPreprocessor): a preprocessor instance
|
||||
db (Database): a database to store tables in the database
|
||||
"""
|
||||
model = model if isinstance(
|
||||
model, TableQuestionAnswering) else Model.from_pretrained(model)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
preprocessor = TableQuestionAnsweringPreprocessor(model.model_dir)
|
||||
self.preprocessor = TableQuestionAnsweringPreprocessor(
|
||||
self.model.model_dir)
|
||||
|
||||
# initilize tokenizer
|
||||
self.tokenizer = BertTokenizer(
|
||||
os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
|
||||
os.path.join(self.model.model_dir, ModelFile.VOCAB_FILE))
|
||||
|
||||
# initialize database
|
||||
if db is None:
|
||||
self.db = Database(
|
||||
tokenizer=self.tokenizer,
|
||||
table_file_path=os.path.join(model.model_dir, 'table.json'),
|
||||
syn_dict_file_path=os.path.join(model.model_dir,
|
||||
table_file_path=os.path.join(self.model.model_dir,
|
||||
'table.json'),
|
||||
syn_dict_file_path=os.path.join(self.model.model_dir,
|
||||
'synonym.txt'))
|
||||
else:
|
||||
self.db = db
|
||||
@@ -71,8 +72,6 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
||||
self.schema_link_dict = constant.schema_link_dict
|
||||
self.limit_dict = constant.limit_dict
|
||||
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
def post_process_multi_turn(self, history_sql, result, table):
|
||||
action = self.action_ops[result['action']]
|
||||
headers = table['header_name']
|
||||
|
||||
@@ -63,16 +63,14 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
|
||||
To view other examples plese check the tests/pipelines/test_text_generation.py.
|
||||
"""
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = Text2TextGenerationPreprocessor(
|
||||
model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
self.tokenizer = preprocessor.tokenizer
|
||||
self.pipeline = model.pipeline.type
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
self.preprocessor = Text2TextGenerationPreprocessor(
|
||||
self.model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
self.tokenizer = self.preprocessor.tokenizer
|
||||
self.pipeline = self.model.pipeline.type
|
||||
self.model.eval()
|
||||
|
||||
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
|
||||
""" Provide specific preprocess for text2text generation pipeline in order to handl multi tasks
|
||||
|
||||
@@ -53,25 +53,24 @@ class TextClassificationPipeline(Pipeline):
|
||||
NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' and 'second_sequence'
|
||||
param will have no affection.
|
||||
"""
|
||||
model = Model.from_pretrained(model) if isinstance(model,
|
||||
str) else model
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
if preprocessor is None:
|
||||
if model.__class__.__name__ == 'OfaForAllTasks':
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
model_name_or_path=model.model_dir,
|
||||
if self.model.__class__.__name__ == 'OfaForAllTasks':
|
||||
self.preprocessor = Preprocessor.from_pretrained(
|
||||
model_name_or_path=self.model.model_dir,
|
||||
type=Preprocessors.ofa_tasks_preprocessor,
|
||||
field=Fields.multi_modal)
|
||||
else:
|
||||
first_sequence = kwargs.pop('first_sequence', 'first_sequence')
|
||||
second_sequence = kwargs.pop('second_sequence', None)
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
model if isinstance(model, str) else model.model_dir,
|
||||
self.preprocessor = Preprocessor.from_pretrained(
|
||||
self.model
|
||||
if isinstance(self.model, str) else self.model.model_dir,
|
||||
first_sequence=first_sequence,
|
||||
second_sequence=second_sequence,
|
||||
sequence_length=kwargs.pop('sequence_length', 512))
|
||||
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.id2label = kwargs.get('id2label')
|
||||
if self.id2label is None and hasattr(self.preprocessor, 'id2label'):
|
||||
self.id2label = self.preprocessor.id2label
|
||||
|
||||
@@ -40,15 +40,13 @@ class TextErrorCorrectionPipeline(Pipeline):
|
||||
|
||||
To view other examples plese check the tests/pipelines/test_text_error_correction.py.
|
||||
"""
|
||||
|
||||
model = model if isinstance(
|
||||
model,
|
||||
BartForTextErrorCorrection) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = TextErrorCorrectionPreprocessor(model.model_dir)
|
||||
self.vocab = preprocessor.vocab
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
if preprocessor is None:
|
||||
self.preprocessor = TextErrorCorrectionPreprocessor(
|
||||
self.model.model_dir)
|
||||
self.vocab = self.preprocessor.vocab
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
with torch.no_grad():
|
||||
|
||||
@@ -51,15 +51,14 @@ class TextGenerationPipeline(Pipeline):
|
||||
|
||||
To view other examples plese check the tests/pipelines/test_text_generation.py.
|
||||
"""
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
cfg = read_config(model.model_dir)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
cfg = read_config(self.model.model_dir)
|
||||
self.postprocessor = cfg.pop('postprocessor', 'decode')
|
||||
if preprocessor is None:
|
||||
preprocessor_cfg = cfg.preprocessor
|
||||
preprocessor_cfg.update({
|
||||
'model_dir':
|
||||
model.model_dir,
|
||||
self.model.model_dir,
|
||||
'first_sequence':
|
||||
first_sequence,
|
||||
'second_sequence':
|
||||
@@ -67,9 +66,9 @@ class TextGenerationPipeline(Pipeline):
|
||||
'sequence_length':
|
||||
kwargs.pop('sequence_length', 128)
|
||||
})
|
||||
preprocessor = build_preprocessor(preprocessor_cfg, Fields.nlp)
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.preprocessor = build_preprocessor(preprocessor_cfg,
|
||||
Fields.nlp)
|
||||
self.model.eval()
|
||||
|
||||
def _sanitize_parameters(self, **pipeline_parameters):
|
||||
return {}, pipeline_parameters, {}
|
||||
|
||||
@@ -32,14 +32,12 @@ class TextRankingPipeline(Pipeline):
|
||||
the model if supplied.
|
||||
sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.
|
||||
"""
|
||||
model = Model.from_pretrained(model) if isinstance(model,
|
||||
str) else model
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
model.model_dir,
|
||||
self.preprocessor = Preprocessor.from_pretrained(
|
||||
self.model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
|
||||
@@ -39,15 +39,14 @@ class TokenClassificationPipeline(Pipeline):
|
||||
model (str or Model): A model instance or a model local dir or a model id in the model hub.
|
||||
preprocessor (Preprocessor): a preprocessor instance, must not be None.
|
||||
"""
|
||||
model = Model.from_pretrained(model) if isinstance(model,
|
||||
str) else model
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
model.model_dir,
|
||||
self.preprocessor = Preprocessor.from_pretrained(
|
||||
self.model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.eval()
|
||||
|
||||
self.id2label = kwargs.get('id2label')
|
||||
if self.id2label is None and hasattr(self.preprocessor, 'id2label'):
|
||||
self.id2label = self.preprocessor.id2label
|
||||
|
||||
@@ -27,10 +27,10 @@ class TranslationQualityEstimationPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str, device: str = 'gpu', **kwargs):
|
||||
super().__init__(model=model, device=device)
|
||||
model_file = os.path.join(model, ModelFile.TORCH_MODEL_FILE)
|
||||
model_file = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE)
|
||||
with open(model_file, 'rb') as f:
|
||||
buffer = io.BytesIO(f.read())
|
||||
self.tokenizer = XLMRobertaTokenizer.from_pretrained(model)
|
||||
self.tokenizer = XLMRobertaTokenizer.from_pretrained(self.model)
|
||||
self.model = torch.jit.load(
|
||||
buffer, map_location=self.device).to(self.device)
|
||||
|
||||
|
||||
@@ -49,14 +49,13 @@ class WordSegmentationPipeline(TokenClassificationPipeline):
|
||||
|
||||
To view other examples plese check the tests/pipelines/test_word_segmentation.py.
|
||||
"""
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = TokenClassificationPreprocessor(
|
||||
model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
if preprocessor is None:
|
||||
self.preprocessor = TokenClassificationPreprocessor(
|
||||
self.model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
self.model.eval()
|
||||
|
||||
self.id2label = kwargs.get('id2label')
|
||||
if self.id2label is None and hasattr(self.preprocessor, 'id2label'):
|
||||
self.id2label = self.preprocessor.id2label
|
||||
|
||||
@@ -59,16 +59,14 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
"""
|
||||
assert isinstance(model, str) or isinstance(model, Model), \
|
||||
'model must be a single str or Model'
|
||||
model = model if isinstance(model,
|
||||
Model) else Model.from_pretrained(model)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.entailment_id = 0
|
||||
self.contradiction_id = 2
|
||||
if preprocessor is None:
|
||||
preprocessor = ZeroShotClassificationPreprocessor(
|
||||
model.model_dir,
|
||||
self.preprocessor = ZeroShotClassificationPreprocessor(
|
||||
self.model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 512))
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model.eval()
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_params = {}
|
||||
|
||||
@@ -105,22 +105,16 @@ class ProteinStructurePipeline(Pipeline):
|
||||
>>> print(pipeline_ins(protein))
|
||||
|
||||
"""
|
||||
import copy
|
||||
model_path = copy.deepcopy(model) if isinstance(model, str) else None
|
||||
cfg = read_config(model_path) # only model is str
|
||||
self.cfg = cfg
|
||||
self.config = model_config(
|
||||
cfg['pipeline']['model_name']) # alphafold config
|
||||
model = model if isinstance(
|
||||
model, Model) else Model.from_pretrained(model_path)
|
||||
self.postprocessor = cfg.pop('postprocessor', None)
|
||||
if preprocessor is None:
|
||||
preprocessor_cfg = cfg.preprocessor
|
||||
preprocessor = build_preprocessor(preprocessor_cfg, Fields.science)
|
||||
model.eval()
|
||||
model.model.inference_mode()
|
||||
model.model_dir = model_path
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.cfg = read_config(self.model.model_dir)
|
||||
self.config = model_config(
|
||||
self.cfg['pipeline']['model_name']) # alphafold config
|
||||
self.postprocessor = self.cfg.pop('postprocessor', None)
|
||||
if preprocessor is None:
|
||||
preprocessor_cfg = self.cfg.preprocessor
|
||||
self.preprocessor = build_preprocessor(preprocessor_cfg,
|
||||
Fields.science)
|
||||
self.model.eval()
|
||||
|
||||
def _sanitize_parameters(self, **pipeline_parameters):
|
||||
return pipeline_parameters, pipeline_parameters, pipeline_parameters
|
||||
|
||||
@@ -6,7 +6,8 @@ from typing import Any, Dict, Optional, Sequence
|
||||
|
||||
from modelscope.metainfo import Models, Preprocessors
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModeKeys, Tasks
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
|
||||
ModeKeys, Tasks)
|
||||
from modelscope.utils.hub import read_config, snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .builder import build_preprocessor
|
||||
@@ -194,7 +195,9 @@ class Preprocessor(ABC):
|
||||
"""
|
||||
if not os.path.exists(model_name_or_path):
|
||||
model_dir = snapshot_download(
|
||||
model_name_or_path, revision=revision)
|
||||
model_name_or_path,
|
||||
revision=revision,
|
||||
user_agent={Invoke.KEY: Invoke.PREPROCESSOR})
|
||||
else:
|
||||
model_dir = model_name_or_path
|
||||
if cfg_dict is None:
|
||||
|
||||
@@ -14,7 +14,8 @@ from modelscope.metainfo import Preprocessors
|
||||
from modelscope.pipelines.base import Input
|
||||
from modelscope.preprocessors import load_image
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
|
||||
from modelscope.utils.constant import (Fields, Invoke, ModeKeys, ModelFile,
|
||||
Tasks)
|
||||
from .base import Preprocessor
|
||||
from .builder import PREPROCESSORS
|
||||
from .ofa import * # noqa
|
||||
@@ -57,7 +58,7 @@ class OfaPreprocessor(Preprocessor):
|
||||
Tasks.auto_speech_recognition: OfaASRPreprocessor
|
||||
}
|
||||
model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
|
||||
model_dir)
|
||||
model_dir, user_agent={Invoke.KEY: Invoke.PREPROCESSOR})
|
||||
self.cfg = Config.from_file(
|
||||
osp.join(model_dir, ModelFile.CONFIGURATION))
|
||||
self.preprocess = preprocess_mapping[self.cfg.task](
|
||||
@@ -131,7 +132,7 @@ class CLIPPreprocessor(Preprocessor):
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
|
||||
model_dir)
|
||||
model_dir, user_agent={Invoke.KEY: Invoke.PREPROCESSOR})
|
||||
self.mode = mode
|
||||
# text tokenizer
|
||||
from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer
|
||||
|
||||
@@ -5,6 +5,7 @@ import random
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from fairseq.data.audio.feature_transforms import \
|
||||
@@ -54,9 +55,13 @@ class OfaASRPreprocessor(OfaBasePreprocessor):
|
||||
|
||||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
speed = random.choice([0.9, 1.0, 1.1])
|
||||
wav, sr = sf.read(self.column_map['wav'])
|
||||
wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True)
|
||||
fbank = self.prepare_fbank(
|
||||
torch.tensor([wav], dtype=torch.float32), sr, speed, is_train=True)
|
||||
torch.tensor([wav], dtype=torch.float32),
|
||||
sr,
|
||||
speed,
|
||||
target_sample_rate=16000,
|
||||
is_train=True)
|
||||
fbank_mask = torch.tensor([True])
|
||||
sample = {
|
||||
'fbank': fbank,
|
||||
@@ -86,11 +91,12 @@ class OfaASRPreprocessor(OfaBasePreprocessor):
|
||||
|
||||
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
speed = 1.0
|
||||
wav, sr = sf.read(data[self.column_map['wav']])
|
||||
wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True)
|
||||
fbank = self.prepare_fbank(
|
||||
torch.tensor([wav], dtype=torch.float32),
|
||||
sr,
|
||||
speed,
|
||||
target_sample_rate=16000,
|
||||
is_train=False)
|
||||
fbank_mask = torch.tensor([True])
|
||||
|
||||
|
||||
@@ -170,10 +170,15 @@ class OfaBasePreprocessor:
|
||||
else load_image(path_or_url_or_pil)
|
||||
return image
|
||||
|
||||
def prepare_fbank(self, waveform, sample_rate, speed, is_train):
|
||||
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
|
||||
def prepare_fbank(self,
|
||||
waveform,
|
||||
sample_rate,
|
||||
speed,
|
||||
target_sample_rate=16000,
|
||||
is_train=False):
|
||||
waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
|
||||
waveform, sample_rate,
|
||||
[['speed', str(speed)], ['rate', str(sample_rate)]])
|
||||
[['speed', str(speed)], ['rate', str(target_sample_rate)]])
|
||||
_waveform, _ = convert_waveform(
|
||||
waveform, sample_rate, to_mono=True, normalize_volume=True)
|
||||
# Kaldi compliance: 16-bit signed integers
|
||||
|
||||
@@ -8,7 +8,6 @@ import torch
|
||||
from torch import nn as nn
|
||||
from torch import optim as optim
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models import Model, TorchModel
|
||||
from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset
|
||||
@@ -54,12 +53,8 @@ class KWSFarfieldTrainer(BaseTrainer):
|
||||
**kwargs):
|
||||
|
||||
if isinstance(model, str):
|
||||
if os.path.exists(model):
|
||||
self.model_dir = model if os.path.isdir(
|
||||
model) else os.path.dirname(model)
|
||||
else:
|
||||
self.model_dir = snapshot_download(
|
||||
model, revision=model_revision)
|
||||
self.model_dir = self.get_or_download_model_dir(
|
||||
model, model_revision)
|
||||
if cfg_file is None:
|
||||
cfg_file = os.path.join(self.model_dir,
|
||||
ModelFile.CONFIGURATION)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Invoke
|
||||
from .utils.log_buffer import LogBuffer
|
||||
|
||||
|
||||
@@ -32,6 +35,17 @@ class BaseTrainer(ABC):
|
||||
self.log_buffer = LogBuffer()
|
||||
self.timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||
|
||||
def get_or_download_model_dir(self, model, model_revision=None):
|
||||
if os.path.exists(model):
|
||||
model_cache_dir = model if os.path.isdir(
|
||||
model) else os.path.dirname(model)
|
||||
else:
|
||||
model_cache_dir = snapshot_download(
|
||||
model,
|
||||
revision=model_revision,
|
||||
user_agent={Invoke.KEY: Invoke.TRAINER})
|
||||
return model_cache_dir
|
||||
|
||||
@abstractmethod
|
||||
def train(self, *args, **kwargs):
|
||||
""" Train (and evaluate) process
|
||||
|
||||
@@ -20,7 +20,7 @@ from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.trainers.optimizer.builder import build_optimizer
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys,
|
||||
ModeKeys)
|
||||
Invoke, ModeKeys)
|
||||
from .clip_trainer_utils import get_loss, get_optimizer_params, get_schedule
|
||||
|
||||
|
||||
@@ -52,7 +52,8 @@ class CLIPTrainer(EpochBasedTrainer):
|
||||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
seed: int = 42,
|
||||
**kwargs):
|
||||
model = Model.from_pretrained(model, revision=model_revision)
|
||||
model = Model.from_pretrained(
|
||||
model, revision=model_revision, invoked_by=Invoke.TRAINER)
|
||||
# for training & eval, we convert the model from FP16 back to FP32
|
||||
# to compatible with modelscope amp training
|
||||
convert_models_to_fp32(model)
|
||||
|
||||
@@ -23,7 +23,7 @@ from modelscope.trainers.optimizer.builder import build_optimizer
|
||||
from modelscope.trainers.parallel.utils import is_parallel
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys,
|
||||
ModeKeys)
|
||||
Invoke, ModeKeys)
|
||||
from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
|
||||
get_schedule)
|
||||
|
||||
@@ -49,7 +49,8 @@ class OFATrainer(EpochBasedTrainer):
|
||||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
seed: int = 42,
|
||||
**kwargs):
|
||||
model = Model.from_pretrained(model, revision=model_revision)
|
||||
model = Model.from_pretrained(
|
||||
model, revision=model_revision, invoked_by=Invoke.TRAINER)
|
||||
model_dir = model.model_dir
|
||||
self.cfg_modify_fn = cfg_modify_fn
|
||||
cfg = self.rebuild_config(Config.from_file(cfg_file))
|
||||
|
||||
@@ -7,21 +7,17 @@ from typing import Callable, Dict, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers.base import BaseTrainer
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.trainers.multi_modal.team.team_trainer_utils import (
|
||||
get_optimizer, train_mapping, val_mapping)
|
||||
from modelscope.trainers.multi_modal.team.team_trainer_utils import \
|
||||
get_optimizer
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DownloadMode, ModeKeys
|
||||
from modelscope.utils.constant import Invoke
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@@ -36,7 +32,7 @@ class TEAMImgClsTrainer(BaseTrainer):
|
||||
super().__init__(cfg_file)
|
||||
|
||||
self.cfg = Config.from_file(cfg_file)
|
||||
team_model = Model.from_pretrained(model)
|
||||
team_model = Model.from_pretrained(model, invoked_by=Invoke.TRAINER)
|
||||
image_model = team_model.model.image_model.vision_transformer
|
||||
classification_model = nn.Sequential(
|
||||
OrderedDict([('encoder', image_model),
|
||||
|
||||
@@ -24,8 +24,7 @@ logger = get_logger()
|
||||
class CsanmtTranslationTrainer(BaseTrainer):
|
||||
|
||||
def __init__(self, model: str, cfg_file: str = None, *args, **kwargs):
|
||||
if not osp.exists(model):
|
||||
model = snapshot_download(model)
|
||||
model = self.get_or_download_model_dir(model)
|
||||
tf.reset_default_graph()
|
||||
|
||||
self.model_dir = model
|
||||
|
||||
@@ -10,7 +10,6 @@ import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.metrics.builder import build_metric
|
||||
from modelscope.models.base import Model, TorchModel
|
||||
@@ -478,11 +477,7 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
|
||||
"""
|
||||
|
||||
if isinstance(model, str):
|
||||
if os.path.exists(model):
|
||||
model_dir = model if os.path.isdir(model) else os.path.dirname(
|
||||
model)
|
||||
else:
|
||||
model_dir = snapshot_download(model, revision=model_revision)
|
||||
model_dir = self.get_or_download_model_dir(model, model_revision)
|
||||
if cfg_file is None:
|
||||
cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
|
||||
else:
|
||||
|
||||
@@ -14,7 +14,6 @@ from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.metrics import build_metric, task_default_metrics
|
||||
from modelscope.models.base import Model, TorchModel
|
||||
@@ -98,12 +97,8 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
self._seed = seed
|
||||
set_random_seed(self._seed)
|
||||
if isinstance(model, str):
|
||||
if os.path.exists(model):
|
||||
self.model_dir = model if os.path.isdir(
|
||||
model) else os.path.dirname(model)
|
||||
else:
|
||||
self.model_dir = snapshot_download(
|
||||
model, revision=model_revision)
|
||||
self.model_dir = self.get_or_download_model_dir(
|
||||
model, model_revision)
|
||||
if cfg_file is None:
|
||||
cfg_file = os.path.join(self.model_dir,
|
||||
ModelFile.CONFIGURATION)
|
||||
|
||||
@@ -44,6 +44,7 @@ class CVTasks(object):
|
||||
|
||||
image_segmentation = 'image-segmentation'
|
||||
semantic_segmentation = 'semantic-segmentation'
|
||||
image_depth_estimation = 'image-depth-estimation'
|
||||
portrait_matting = 'portrait-matting'
|
||||
text_driven_segmentation = 'text-driven-segmentation'
|
||||
shop_segmentation = 'shop-segmentation'
|
||||
@@ -293,6 +294,14 @@ class ModelFile(object):
|
||||
TS_MODEL_FILE = 'model.ts'
|
||||
|
||||
|
||||
class Invoke(object):
|
||||
KEY = 'invoked_by'
|
||||
PRETRAINED = 'from_pretrained'
|
||||
PIPELINE = 'pipeline'
|
||||
TRAINER = 'trainer'
|
||||
PREPROCESSOR = 'preprocessor'
|
||||
|
||||
|
||||
class ConfigFields(object):
|
||||
""" First level keyword in configuration file
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
@@ -439,3 +440,11 @@ def show_image_object_detection_auto_result(img_path,
|
||||
if save_path is not None:
|
||||
cv2.imwrite(save_path, img)
|
||||
return img
|
||||
|
||||
|
||||
def depth_to_color(depth):
|
||||
colormap = plt.get_cmap('plasma')
|
||||
depth_color = (colormap(
|
||||
(depth.max() - depth) / depth.max()) * 2**8).astype(np.uint8)[:, :, :3]
|
||||
depth_color = cv2.cvtColor(depth_color, cv2.COLOR_RGB2BGR)
|
||||
return depth_color
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
ftfy>=6.0.3
|
||||
librosa
|
||||
ofa>=0.0.2
|
||||
pycocoevalcap>=1.2
|
||||
pycocotools>=2.0.4
|
||||
|
||||
@@ -28,7 +28,7 @@ class ExtractiveSummarizationTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
result = p(documents=documents)
|
||||
return result
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_with_doc(self):
|
||||
logger.info(
|
||||
'Run doc extractive summarization (PoNet) with one document ...')
|
||||
@@ -37,7 +37,7 @@ class ExtractiveSummarizationTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
model_id=self.ponet_doc_model_id, documents=self.sentences)
|
||||
print(result[OutputKeys.TEXT])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_with_topic(self):
|
||||
logger.info(
|
||||
'Run topic extractive summarization (PoNet) with one document ...')
|
||||
|
||||
35
tests/pipelines/test_image_depth_estimation.py
Normal file
35
tests/pipelines/test_image_depth_estimation.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import depth_to_color
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class ImageDepthEstimationTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = 'image-depth-estimation'
|
||||
self.model_id = 'damo/cv_newcrfs_image-depth-estimation_indoor'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_image_depth_estimation(self):
|
||||
input_location = 'data/test/images/image_depth_estimation.jpg'
|
||||
estimator = pipeline(Tasks.image_depth_estimation, model=self.model_id)
|
||||
result = estimator(input_location)
|
||||
depths = result[OutputKeys.DEPTHS]
|
||||
depth_viz = depth_to_color(depths[0].squeeze().cpu().numpy())
|
||||
cv2.imwrite('result.jpg', depth_viz)
|
||||
|
||||
print('test_image_depth_estimation DONE')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user