[to #42322933]add damoyolo model in tinynas-object-detection

接入damyolo系列检测模型
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10377688
This commit is contained in:
xianzhe.xxz
2022-10-18 16:53:29 +08:00
committed by yingda.chen
parent 6438c41144
commit 865397763e
11 changed files with 126 additions and 44 deletions

View File

@@ -9,7 +9,9 @@ class Models(object):
Model name should only contain model info but not task info.
"""
# tinynas models
tinynas_detection = 'tinynas-detection'
tinynas_damoyolo = 'tinynas-damoyolo'
# vision models
detection = 'detection'

View File

@@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .tinynas_detector import Tinynas_detector
from .tinynas_damoyolo import DamoYolo
else:
_import_structure = {
'tinynas_detector': ['TinynasDetector'],
'tinynas_damoyolo': ['DamoYolo'],
}
import sys

View File

@@ -4,6 +4,7 @@
import torch
import torch.nn as nn
from modelscope.utils.file_utils import read_file
from ..core.base_ops import Focus, SPPBottleneck, get_activation
from ..core.repvgg_block import RepVggBlock
@@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module):
kernel_size,
stride,
force_resproj=False,
act='silu'):
act='silu',
reparam=False):
super(ResConvK1KX, self).__init__()
self.stride = stride
self.conv1 = ConvKXBN(in_c, btn_c, 1, 1)
self.conv2 = RepVggBlock(
btn_c, out_c, kernel_size, stride, act='identity')
if not reparam:
self.conv2 = ConvKXBN(btn_c, out_c, 3, stride)
else:
self.conv2 = RepVggBlock(
btn_c, out_c, kernel_size, stride, act='identity')
if act is None:
self.activation_function = torch.relu
@@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module):
stride,
num_blocks,
with_spp=False,
act='silu'):
act='silu',
reparam=False):
super(SuperResConvK1KX, self).__init__()
if act is None:
self.act = torch.relu
@@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module):
this_kernel_size,
this_stride,
force_resproj,
act=act)
act=act,
reparam=reparam)
self.block_list.append(the_block)
if block_id == 0 and with_spp:
self.block_list.append(
@@ -248,7 +255,8 @@ class TinyNAS(nn.Module):
with_spp=False,
use_focus=False,
need_conv1=True,
act='silu'):
act='silu',
reparam=False):
super(TinyNAS, self).__init__()
assert len(out_indices) == len(out_channels)
self.out_indices = out_indices
@@ -281,7 +289,8 @@ class TinyNAS(nn.Module):
block_info['s'],
block_info['L'],
spp,
act=act)
act=act,
reparam=reparam)
self.block_list.append(the_block)
elif the_block_class == 'SuperResConvKXKX':
spp = with_spp if idx == len(structure_info) - 1 else False
@@ -325,8 +334,8 @@ class TinyNAS(nn.Module):
def load_tinynas_net(backbone_cfg):
# load masternet model to path
import ast
struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str])
net_structure_str = read_file(backbone_cfg.structure_file)
struct_str = ''.join([x.strip() for x in net_structure_str])
struct_info = ast.literal_eval(struct_str)
for layer in struct_info:
if 'nbitsA' in layer:
@@ -342,6 +351,6 @@ def load_tinynas_net(backbone_cfg):
use_focus=backbone_cfg.use_focus,
act=backbone_cfg.act,
need_conv1=backbone_cfg.need_conv1,
)
reparam=backbone_cfg.reparam)
return model

View File

@@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel):
"""
super().__init__(model_dir, *args, **kwargs)
config_path = osp.join(model_dir, 'airdet_s.py')
config_path = osp.join(model_dir, self.config_name)
config = parse_config(config_path)
self.cfg = config
model_path = osp.join(model_dir, config.model.name)
@@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel):
self.conf_thre = config.model.head.nms_conf_thre
self.nms_thre = config.model.head.nms_iou_thre
if self.cfg.model.backbone.name == 'TinyNAS':
self.cfg.model.backbone.structure_file = osp.join(
model_dir, self.cfg.model.backbone.structure_file)
self.backbone = build_backbone(self.cfg.model.backbone)
self.neck = build_neck(self.cfg.model.neck)
self.head = build_head(self.cfg.model.head)

View File

@@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module):
simOTA_iou_weight=3.0,
octbase=8,
simlqe=False,
use_lqe=True,
**kwargs):
self.simlqe = simlqe
self.num_classes = num_classes
self.in_channels = in_channels
self.strides = strides
self.use_lqe = use_lqe
self.feat_channels = feat_channels if isinstance(feat_channels, list) \
else [feat_channels] * len(self.strides)
@@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module):
groups=self.conv_groups,
norm=self.norm,
act=self.act))
if not self.simlqe:
conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)]
if self.use_lqe:
if not self.simlqe:
conf_vector = [
nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)
]
else:
conf_vector = [
nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
]
conf_vector += [self.relu]
conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
reg_conf = nn.Sequential(*conf_vector)
else:
conf_vector = [
nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
]
conf_vector += [self.relu]
conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
reg_conf = nn.Sequential(*conf_vector)
reg_conf = None
return cls_convs, reg_convs, reg_conf
@@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module):
N, C, H, W = bbox_pred.size()
prob = F.softmax(
bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2)
if not self.simlqe:
prob_topk, _ = prob.topk(self.reg_topk, dim=2)
if self.use_lqe:
if not self.simlqe:
prob_topk, _ = prob.topk(self.reg_topk, dim=2)
if self.add_mean:
stat = torch.cat(
[prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2)
if self.add_mean:
stat = torch.cat(
[prob_topk,
prob_topk.mean(dim=2, keepdim=True)],
dim=2)
else:
stat = prob_topk
quality_score = reg_conf(
stat.reshape(N, 4 * self.total_dim, H, W))
else:
stat = prob_topk
quality_score = reg_conf(
bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))
quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W))
cls_score = gfl_cls(cls_feat).sigmoid() * quality_score
else:
quality_score = reg_conf(
bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))
cls_score = gfl_cls(cls_feat).sigmoid() * quality_score
cls_score = gfl_cls(cls_feat).sigmoid()
flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2)
flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2)

View File

@@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module):
self,
depth=1.0,
width=1.0,
in_features=[2, 3, 4],
in_channels=[256, 512, 1024],
out_channels=[256, 512, 1024],
depthwise=False,
@@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module):
block_name='BasicBlock',
):
super().__init__()
self.in_features = in_features
self.in_channels = in_channels
Conv = DWConv if depthwise else BaseConv
@@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module):
"""
# backbone
features = [out_features[f] for f in self.in_features]
[x2, x1, x0] = features
[x2, x1, x0] = out_features
# node x3
x13 = self.bu_conv13(x1)

View File

@@ -0,0 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from .detector import SingleStageDetector
@MODELS.register_module(
Tasks.image_object_detection, module_name=Models.tinynas_damoyolo)
class DamoYolo(SingleStageDetector):
def __init__(self, model_dir, *args, **kwargs):
self.config_name = 'damoyolo_s.py'
super(DamoYolo, self).__init__(model_dir, *args, **kwargs)

View File

@@ -12,5 +12,5 @@ from .detector import SingleStageDetector
class TinynasDetector(SingleStageDetector):
def __init__(self, model_dir, *args, **kwargs):
self.config_name = 'airdet_s.py'
super(TinynasDetector, self).__init__(model_dir, *args, **kwargs)

View File

@@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import \
show_image_object_detection_auto_result
from modelscope.utils.logger import get_logger
logger = get_logger()
@@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline):
bboxes, scores, labels = self.model.postprocess(inputs['data'])
if bboxes is None:
return None
outputs = {
OutputKeys.SCORES: scores,
OutputKeys.LABELS: labels,
OutputKeys.BOXES: bboxes
}
outputs = {
OutputKeys.SCORES: [],
OutputKeys.LABELS: [],
OutputKeys.BOXES: []
}
else:
outputs = {
OutputKeys.SCORES: scores,
OutputKeys.LABELS: labels,
OutputKeys.BOXES: bboxes
}
return outputs
def show_result(self, img_path, result, save_path=None):
show_image_object_detection_auto_result(img_path, result, save_path)

View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import os
from pathlib import Path
@@ -35,3 +36,10 @@ def get_default_cache_dir():
"""
default_cache_dir = Path.home().joinpath('.cache', 'modelscope')
return default_cache_dir
def read_file(path):
with open(path, 'r') as f:
text = f.read()
return text

View File

@@ -4,22 +4,45 @@ import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class TinynasObjectDetectionTest(unittest.TestCase):
class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
self.task = Tasks.image_object_detection
self.model_id = 'damo/cv_tinynas_object-detection_damoyolo'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run(self):
def test_run_airdet(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection, model='damo/cv_tinynas_detection')
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print(result)
@unittest.skip('will be enabled after damoyolo officially released')
def test_run_damoyolo(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo')
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print(result)
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.test_demo()
self.compatibility_check()
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_image_object_detection_auto_pipeline(self):
test_image = 'data/test/images/image_detection.jpg'
tinynas_object_detection = pipeline(
Tasks.image_object_detection, model='damo/cv_tinynas_detection')
result = tinynas_object_detection(test_image)
tinynas_object_detection.show_result(test_image, result,
'demo_ret.jpg')
if __name__ == '__main__':