[to #42322933] add damoyolo model in tinynas-object-detection

Integrate the damoyolo series of detection models.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10377688
This commit is contained in:
xianzhe.xxz
2022-10-18 16:53:29 +08:00
committed by yingda.chen
parent 6438c41144
commit 865397763e
11 changed files with 126 additions and 44 deletions
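
Once merged, the new model is exercised through the standard pipeline API. A minimal usage sketch, mirroring the test case added in this commit (the model id and sample image path below are the ones used in that test):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build an image object detection pipeline backed by the newly registered DamoYolo model.
    tinynas_object_detection = pipeline(
        Tasks.image_object_detection,
        model='damo/cv_tinynas_object-detection_damoyolo')

    # Run inference on a local image; the result carries scores, labels and boxes.
    result = tinynas_object_detection('data/test/images/image_detection.jpg')
    print(result)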

View File

@@ -9,7 +9,9 @@ class Models(object):
     Model name should only contain model info but not task info.
     """
+    # tinynas models
     tinynas_detection = 'tinynas-detection'
+    tinynas_damoyolo = 'tinynas-damoyolo'

     # vision models
     detection = 'detection'

View File

@@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .tinynas_detector import Tinynas_detector
+    from .tinynas_damoyolo import DamoYolo

 else:
     _import_structure = {
         'tinynas_detector': ['TinynasDetector'],
+        'tinynas_damoyolo': ['DamoYolo'],
     }

     import sys

View File

@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn

+from modelscope.utils.file_utils import read_file
 from ..core.base_ops import Focus, SPPBottleneck, get_activation
 from ..core.repvgg_block import RepVggBlock
@@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module):
                  kernel_size,
                  stride,
                  force_resproj=False,
-                 act='silu'):
+                 act='silu',
+                 reparam=False):
         super(ResConvK1KX, self).__init__()
         self.stride = stride
         self.conv1 = ConvKXBN(in_c, btn_c, 1, 1)
-        self.conv2 = RepVggBlock(
-            btn_c, out_c, kernel_size, stride, act='identity')
+        if not reparam:
+            self.conv2 = ConvKXBN(btn_c, out_c, 3, stride)
+        else:
+            self.conv2 = RepVggBlock(
+                btn_c, out_c, kernel_size, stride, act='identity')

         if act is None:
             self.activation_function = torch.relu
@@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module):
                  stride,
                  num_blocks,
                  with_spp=False,
-                 act='silu'):
+                 act='silu',
+                 reparam=False):
         super(SuperResConvK1KX, self).__init__()
         if act is None:
             self.act = torch.relu
@@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module):
                 this_kernel_size,
                 this_stride,
                 force_resproj,
-                act=act)
+                act=act,
+                reparam=reparam)
             self.block_list.append(the_block)
             if block_id == 0 and with_spp:
                 self.block_list.append(
@@ -248,7 +255,8 @@ class TinyNAS(nn.Module):
                  with_spp=False,
                  use_focus=False,
                  need_conv1=True,
-                 act='silu'):
+                 act='silu',
+                 reparam=False):
         super(TinyNAS, self).__init__()
         assert len(out_indices) == len(out_channels)
         self.out_indices = out_indices
@@ -281,7 +289,8 @@ class TinyNAS(nn.Module):
                     block_info['s'],
                     block_info['L'],
                     spp,
-                    act=act)
+                    act=act,
+                    reparam=reparam)
                 self.block_list.append(the_block)
             elif the_block_class == 'SuperResConvKXKX':
                 spp = with_spp if idx == len(structure_info) - 1 else False
@@ -325,8 +334,8 @@ class TinyNAS(nn.Module):
 def load_tinynas_net(backbone_cfg):
     # load masternet model to path
     import ast
-    struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str])
+    net_structure_str = read_file(backbone_cfg.structure_file)
+    struct_str = ''.join([x.strip() for x in net_structure_str])
     struct_info = ast.literal_eval(struct_str)
     for layer in struct_info:
         if 'nbitsA' in layer:
@@ -342,6 +351,6 @@ def load_tinynas_net(backbone_cfg):
         use_focus=backbone_cfg.use_focus,
         act=backbone_cfg.act,
         need_conv1=backbone_cfg.need_conv1,
-    )
+        reparam=backbone_cfg.reparam)
     return model
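
With these changes the TinyNAS structure string is no longer inlined in the config; load_tinynas_net reads it from backbone_cfg.structure_file (resolved against model_dir by the detector, see below), and a reparam switch is threaded through every block. A rough sketch of what the backbone section of such a config might look like; only the field names actually read by load_tinynas_net come from the code, while the file name, channels and flag values are illustrative assumptions:

    # Hypothetical backbone config (values are illustrative, not from this commit).
    backbone = dict(
        name='TinyNAS',  # triggers structure_file path resolution in SingleStageDetector
        structure_file='tinynas_structure.txt',  # plain-text structure, read via read_file()
        out_indices=(2, 4, 5),
        out_channels=[128, 256, 512],
        with_spp=True,
        use_focus=True,
        act='silu',
        need_conv1=False,
        reparam=True,  # new flag: build a RepVggBlock conv2 instead of a plain ConvKXBN
    )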

View File

@@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel):
         """
         super().__init__(model_dir, *args, **kwargs)
-        config_path = osp.join(model_dir, 'airdet_s.py')
+        config_path = osp.join(model_dir, self.config_name)
         config = parse_config(config_path)
         self.cfg = config
         model_path = osp.join(model_dir, config.model.name)
@@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel):
         self.conf_thre = config.model.head.nms_conf_thre
         self.nms_thre = config.model.head.nms_iou_thre
+        if self.cfg.model.backbone.name == 'TinyNAS':
+            self.cfg.model.backbone.structure_file = osp.join(
+                model_dir, self.cfg.model.backbone.structure_file)
         self.backbone = build_backbone(self.cfg.model.backbone)
         self.neck = build_neck(self.cfg.model.neck)
         self.head = build_head(self.cfg.model.head)

View File

@@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module):
                  simOTA_iou_weight=3.0,
                  octbase=8,
                  simlqe=False,
+                 use_lqe=True,
                  **kwargs):
         self.simlqe = simlqe
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.strides = strides
+        self.use_lqe = use_lqe
         self.feat_channels = feat_channels if isinstance(feat_channels, list) \
             else [feat_channels] * len(self.strides)
@@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module):
                     groups=self.conv_groups,
                     norm=self.norm,
                     act=self.act))
-        if not self.simlqe:
-            conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)]
+        if self.use_lqe:
+            if not self.simlqe:
+                conf_vector = [
+                    nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)
+                ]
+            else:
+                conf_vector = [
+                    nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
+                ]
+            conf_vector += [self.relu]
+            conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
+            reg_conf = nn.Sequential(*conf_vector)
         else:
-            conf_vector = [
-                nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
-            ]
-        conf_vector += [self.relu]
-        conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
-        reg_conf = nn.Sequential(*conf_vector)
+            reg_conf = None

         return cls_convs, reg_convs, reg_conf
@@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module):
         N, C, H, W = bbox_pred.size()
         prob = F.softmax(
             bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2)
-        if not self.simlqe:
-            prob_topk, _ = prob.topk(self.reg_topk, dim=2)
-
-            if self.add_mean:
-                stat = torch.cat(
-                    [prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2)
-            else:
-                stat = prob_topk
-
-            quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W))
-        else:
-            quality_score = reg_conf(
-                bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))
-        cls_score = gfl_cls(cls_feat).sigmoid() * quality_score
+        if self.use_lqe:
+            if not self.simlqe:
+                prob_topk, _ = prob.topk(self.reg_topk, dim=2)
+
+                if self.add_mean:
+                    stat = torch.cat(
+                        [prob_topk,
+                         prob_topk.mean(dim=2, keepdim=True)],
+                        dim=2)
+                else:
+                    stat = prob_topk
+
+                quality_score = reg_conf(
+                    stat.reshape(N, 4 * self.total_dim, H, W))
+            else:
+                quality_score = reg_conf(
+                    bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))
+            cls_score = gfl_cls(cls_feat).sigmoid() * quality_score
+        else:
+            cls_score = gfl_cls(cls_feat).sigmoid()

         flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2)
         flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2)

View File

@@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module):
         self,
         depth=1.0,
         width=1.0,
-        in_features=[2, 3, 4],
         in_channels=[256, 512, 1024],
         out_channels=[256, 512, 1024],
         depthwise=False,
@@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module):
         block_name='BasicBlock',
     ):
         super().__init__()
-        self.in_features = in_features
         self.in_channels = in_channels

         Conv = DWConv if depthwise else BaseConv
@@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module):
         """
         # backbone
-        features = [out_features[f] for f in self.in_features]
-        [x2, x1, x0] = features
+        [x2, x1, x0] = out_features

         # node x3
         x13 = self.bu_conv13(x1)

View File

@@ -0,0 +1,15 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from modelscope.metainfo import Models
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import Tasks
+
+from .detector import SingleStageDetector
+
+
+@MODELS.register_module(
+    Tasks.image_object_detection, module_name=Models.tinynas_damoyolo)
+class DamoYolo(SingleStageDetector):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        self.config_name = 'damoyolo_s.py'
+        super(DamoYolo, self).__init__(model_dir, *args, **kwargs)
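
The subclass only has to set config_name before delegating to SingleStageDetector, which now resolves its config path from that attribute (see the detector change above). Any further variant would follow the same shape; a hypothetical sketch using the same imports as the new file above (the class name, module_name string and config file below are made up, not part of this commit):

    # Hypothetical additional variant, shown only to illustrate the extension point.
    @MODELS.register_module(
        Tasks.image_object_detection, module_name='tinynas-damoyolo-m')
    class DamoYoloM(SingleStageDetector):

        def __init__(self, model_dir, *args, **kwargs):
            # Variants differ only in which config file they load from model_dir.
            self.config_name = 'damoyolo_m.py'
            super(DamoYoloM, self).__init__(model_dir, *args, **kwargs)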

View File

@@ -12,5 +12,5 @@ from .detector import SingleStageDetector
 class TinynasDetector(SingleStageDetector):

     def __init__(self, model_dir, *args, **kwargs):
+        self.config_name = 'airdet_s.py'
         super(TinynasDetector, self).__init__(model_dir, *args, **kwargs)

View File

@@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import Tasks
+from modelscope.utils.cv.image_utils import \
+    show_image_object_detection_auto_result
 from modelscope.utils.logger import get_logger

 logger = get_logger()
@@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline):
         bboxes, scores, labels = self.model.postprocess(inputs['data'])
         if bboxes is None:
-            return None
-        outputs = {
-            OutputKeys.SCORES: scores,
-            OutputKeys.LABELS: labels,
-            OutputKeys.BOXES: bboxes
-        }
+            outputs = {
+                OutputKeys.SCORES: [],
+                OutputKeys.LABELS: [],
+                OutputKeys.BOXES: []
+            }
+        else:
+            outputs = {
+                OutputKeys.SCORES: scores,
+                OutputKeys.LABELS: labels,
+                OutputKeys.BOXES: bboxes
+            }
         return outputs
+
+    def show_result(self, img_path, result, save_path=None):
+        show_image_object_detection_auto_result(img_path, result, save_path)

View File

@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import inspect
+import os
 from pathlib import Path
@@ -35,3 +36,10 @@ def get_default_cache_dir():
     """
     default_cache_dir = Path.home().joinpath('.cache', 'modelscope')
     return default_cache_dir
+
+
+def read_file(path):
+    with open(path, 'r') as f:
+        text = f.read()
+    return text

View File

@@ -4,22 +4,45 @@ import unittest

 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
 from modelscope.utils.test_utils import test_level


-class TinynasObjectDetectionTest(unittest.TestCase):
+class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.task = Tasks.image_object_detection
+        self.model_id = 'damo/cv_tinynas_object-detection_damoyolo'

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run(self):
+    def test_run_airdet(self):
         tinynas_object_detection = pipeline(
             Tasks.image_object_detection, model='damo/cv_tinynas_detection')
         result = tinynas_object_detection(
             'data/test/images/image_detection.jpg')
         print(result)

+    @unittest.skip('will be enabled after damoyolo officially released')
+    def test_run_damoyolo(self):
+        tinynas_object_detection = pipeline(
+            Tasks.image_object_detection,
+            model='damo/cv_tinynas_object-detection_damoyolo')
+        result = tinynas_object_detection(
+            'data/test/images/image_detection.jpg')
+        print(result)
+
     @unittest.skip('demo compatibility test is only enabled on a needed-basis')
     def test_demo_compatibility(self):
-        self.test_demo()
+        self.compatibility_check()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_image_object_detection_auto_pipeline(self):
+        test_image = 'data/test/images/image_detection.jpg'
+        tinynas_object_detection = pipeline(
+            Tasks.image_object_detection, model='damo/cv_tinynas_detection')
+        result = tinynas_object_detection(test_image)
+        tinynas_object_detection.show_result(test_image, result,
+                                             'demo_ret.jpg')
+

 if __name__ == '__main__':