mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
[to #42322933]add damoyolo model in tinynas-object-detection
接入damyolo系列检测模型
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10377688
This commit is contained in:
@@ -9,7 +9,9 @@ class Models(object):
|
|||||||
|
|
||||||
Model name should only contain model info but not task info.
|
Model name should only contain model info but not task info.
|
||||||
"""
|
"""
|
||||||
|
# tinynas models
|
||||||
tinynas_detection = 'tinynas-detection'
|
tinynas_detection = 'tinynas-detection'
|
||||||
|
tinynas_damoyolo = 'tinynas-damoyolo'
|
||||||
|
|
||||||
# vision models
|
# vision models
|
||||||
detection = 'detection'
|
detection = 'detection'
|
||||||
|
|||||||
@@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .tinynas_detector import Tinynas_detector
|
from .tinynas_detector import Tinynas_detector
|
||||||
|
from .tinynas_damoyolo import DamoYolo
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
'tinynas_detector': ['TinynasDetector'],
|
'tinynas_detector': ['TinynasDetector'],
|
||||||
|
'tinynas_damoyolo': ['DamoYolo'],
|
||||||
}
|
}
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from modelscope.utils.file_utils import read_file
|
||||||
from ..core.base_ops import Focus, SPPBottleneck, get_activation
|
from ..core.base_ops import Focus, SPPBottleneck, get_activation
|
||||||
from ..core.repvgg_block import RepVggBlock
|
from ..core.repvgg_block import RepVggBlock
|
||||||
|
|
||||||
@@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module):
|
|||||||
kernel_size,
|
kernel_size,
|
||||||
stride,
|
stride,
|
||||||
force_resproj=False,
|
force_resproj=False,
|
||||||
act='silu'):
|
act='silu',
|
||||||
|
reparam=False):
|
||||||
super(ResConvK1KX, self).__init__()
|
super(ResConvK1KX, self).__init__()
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.conv1 = ConvKXBN(in_c, btn_c, 1, 1)
|
self.conv1 = ConvKXBN(in_c, btn_c, 1, 1)
|
||||||
self.conv2 = RepVggBlock(
|
if not reparam:
|
||||||
btn_c, out_c, kernel_size, stride, act='identity')
|
self.conv2 = ConvKXBN(btn_c, out_c, 3, stride)
|
||||||
|
else:
|
||||||
|
self.conv2 = RepVggBlock(
|
||||||
|
btn_c, out_c, kernel_size, stride, act='identity')
|
||||||
|
|
||||||
if act is None:
|
if act is None:
|
||||||
self.activation_function = torch.relu
|
self.activation_function = torch.relu
|
||||||
@@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module):
|
|||||||
stride,
|
stride,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
with_spp=False,
|
with_spp=False,
|
||||||
act='silu'):
|
act='silu',
|
||||||
|
reparam=False):
|
||||||
super(SuperResConvK1KX, self).__init__()
|
super(SuperResConvK1KX, self).__init__()
|
||||||
if act is None:
|
if act is None:
|
||||||
self.act = torch.relu
|
self.act = torch.relu
|
||||||
@@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module):
|
|||||||
this_kernel_size,
|
this_kernel_size,
|
||||||
this_stride,
|
this_stride,
|
||||||
force_resproj,
|
force_resproj,
|
||||||
act=act)
|
act=act,
|
||||||
|
reparam=reparam)
|
||||||
self.block_list.append(the_block)
|
self.block_list.append(the_block)
|
||||||
if block_id == 0 and with_spp:
|
if block_id == 0 and with_spp:
|
||||||
self.block_list.append(
|
self.block_list.append(
|
||||||
@@ -248,7 +255,8 @@ class TinyNAS(nn.Module):
|
|||||||
with_spp=False,
|
with_spp=False,
|
||||||
use_focus=False,
|
use_focus=False,
|
||||||
need_conv1=True,
|
need_conv1=True,
|
||||||
act='silu'):
|
act='silu',
|
||||||
|
reparam=False):
|
||||||
super(TinyNAS, self).__init__()
|
super(TinyNAS, self).__init__()
|
||||||
assert len(out_indices) == len(out_channels)
|
assert len(out_indices) == len(out_channels)
|
||||||
self.out_indices = out_indices
|
self.out_indices = out_indices
|
||||||
@@ -281,7 +289,8 @@ class TinyNAS(nn.Module):
|
|||||||
block_info['s'],
|
block_info['s'],
|
||||||
block_info['L'],
|
block_info['L'],
|
||||||
spp,
|
spp,
|
||||||
act=act)
|
act=act,
|
||||||
|
reparam=reparam)
|
||||||
self.block_list.append(the_block)
|
self.block_list.append(the_block)
|
||||||
elif the_block_class == 'SuperResConvKXKX':
|
elif the_block_class == 'SuperResConvKXKX':
|
||||||
spp = with_spp if idx == len(structure_info) - 1 else False
|
spp = with_spp if idx == len(structure_info) - 1 else False
|
||||||
@@ -325,8 +334,8 @@ class TinyNAS(nn.Module):
|
|||||||
def load_tinynas_net(backbone_cfg):
|
def load_tinynas_net(backbone_cfg):
|
||||||
# load masternet model to path
|
# load masternet model to path
|
||||||
import ast
|
import ast
|
||||||
|
net_structure_str = read_file(backbone_cfg.structure_file)
|
||||||
struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str])
|
struct_str = ''.join([x.strip() for x in net_structure_str])
|
||||||
struct_info = ast.literal_eval(struct_str)
|
struct_info = ast.literal_eval(struct_str)
|
||||||
for layer in struct_info:
|
for layer in struct_info:
|
||||||
if 'nbitsA' in layer:
|
if 'nbitsA' in layer:
|
||||||
@@ -342,6 +351,6 @@ def load_tinynas_net(backbone_cfg):
|
|||||||
use_focus=backbone_cfg.use_focus,
|
use_focus=backbone_cfg.use_focus,
|
||||||
act=backbone_cfg.act,
|
act=backbone_cfg.act,
|
||||||
need_conv1=backbone_cfg.need_conv1,
|
need_conv1=backbone_cfg.need_conv1,
|
||||||
)
|
reparam=backbone_cfg.reparam)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel):
|
|||||||
"""
|
"""
|
||||||
super().__init__(model_dir, *args, **kwargs)
|
super().__init__(model_dir, *args, **kwargs)
|
||||||
|
|
||||||
config_path = osp.join(model_dir, 'airdet_s.py')
|
config_path = osp.join(model_dir, self.config_name)
|
||||||
config = parse_config(config_path)
|
config = parse_config(config_path)
|
||||||
self.cfg = config
|
self.cfg = config
|
||||||
model_path = osp.join(model_dir, config.model.name)
|
model_path = osp.join(model_dir, config.model.name)
|
||||||
@@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel):
|
|||||||
self.conf_thre = config.model.head.nms_conf_thre
|
self.conf_thre = config.model.head.nms_conf_thre
|
||||||
self.nms_thre = config.model.head.nms_iou_thre
|
self.nms_thre = config.model.head.nms_iou_thre
|
||||||
|
|
||||||
|
if self.cfg.model.backbone.name == 'TinyNAS':
|
||||||
|
self.cfg.model.backbone.structure_file = osp.join(
|
||||||
|
model_dir, self.cfg.model.backbone.structure_file)
|
||||||
self.backbone = build_backbone(self.cfg.model.backbone)
|
self.backbone = build_backbone(self.cfg.model.backbone)
|
||||||
self.neck = build_neck(self.cfg.model.neck)
|
self.neck = build_neck(self.cfg.model.neck)
|
||||||
self.head = build_head(self.cfg.model.head)
|
self.head = build_head(self.cfg.model.head)
|
||||||
|
|||||||
@@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module):
|
|||||||
simOTA_iou_weight=3.0,
|
simOTA_iou_weight=3.0,
|
||||||
octbase=8,
|
octbase=8,
|
||||||
simlqe=False,
|
simlqe=False,
|
||||||
|
use_lqe=True,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
self.simlqe = simlqe
|
self.simlqe = simlqe
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.strides = strides
|
self.strides = strides
|
||||||
|
self.use_lqe = use_lqe
|
||||||
self.feat_channels = feat_channels if isinstance(feat_channels, list) \
|
self.feat_channels = feat_channels if isinstance(feat_channels, list) \
|
||||||
else [feat_channels] * len(self.strides)
|
else [feat_channels] * len(self.strides)
|
||||||
|
|
||||||
@@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module):
|
|||||||
groups=self.conv_groups,
|
groups=self.conv_groups,
|
||||||
norm=self.norm,
|
norm=self.norm,
|
||||||
act=self.act))
|
act=self.act))
|
||||||
if not self.simlqe:
|
if self.use_lqe:
|
||||||
conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)]
|
if not self.simlqe:
|
||||||
|
conf_vector = [
|
||||||
|
nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
conf_vector = [
|
||||||
|
nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
|
||||||
|
]
|
||||||
|
conf_vector += [self.relu]
|
||||||
|
conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
|
||||||
|
reg_conf = nn.Sequential(*conf_vector)
|
||||||
else:
|
else:
|
||||||
conf_vector = [
|
reg_conf = None
|
||||||
nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
|
|
||||||
]
|
|
||||||
conf_vector += [self.relu]
|
|
||||||
conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
|
|
||||||
reg_conf = nn.Sequential(*conf_vector)
|
|
||||||
|
|
||||||
return cls_convs, reg_convs, reg_conf
|
return cls_convs, reg_convs, reg_conf
|
||||||
|
|
||||||
@@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module):
|
|||||||
N, C, H, W = bbox_pred.size()
|
N, C, H, W = bbox_pred.size()
|
||||||
prob = F.softmax(
|
prob = F.softmax(
|
||||||
bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2)
|
bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2)
|
||||||
if not self.simlqe:
|
if self.use_lqe:
|
||||||
prob_topk, _ = prob.topk(self.reg_topk, dim=2)
|
if not self.simlqe:
|
||||||
|
prob_topk, _ = prob.topk(self.reg_topk, dim=2)
|
||||||
|
|
||||||
if self.add_mean:
|
if self.add_mean:
|
||||||
stat = torch.cat(
|
stat = torch.cat(
|
||||||
[prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2)
|
[prob_topk,
|
||||||
|
prob_topk.mean(dim=2, keepdim=True)],
|
||||||
|
dim=2)
|
||||||
|
else:
|
||||||
|
stat = prob_topk
|
||||||
|
|
||||||
|
quality_score = reg_conf(
|
||||||
|
stat.reshape(N, 4 * self.total_dim, H, W))
|
||||||
else:
|
else:
|
||||||
stat = prob_topk
|
quality_score = reg_conf(
|
||||||
|
bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))
|
||||||
|
|
||||||
quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W))
|
cls_score = gfl_cls(cls_feat).sigmoid() * quality_score
|
||||||
else:
|
else:
|
||||||
quality_score = reg_conf(
|
cls_score = gfl_cls(cls_feat).sigmoid()
|
||||||
bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))
|
|
||||||
|
|
||||||
cls_score = gfl_cls(cls_feat).sigmoid() * quality_score
|
|
||||||
|
|
||||||
flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2)
|
flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2)
|
||||||
flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2)
|
flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2)
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module):
|
|||||||
self,
|
self,
|
||||||
depth=1.0,
|
depth=1.0,
|
||||||
width=1.0,
|
width=1.0,
|
||||||
in_features=[2, 3, 4],
|
|
||||||
in_channels=[256, 512, 1024],
|
in_channels=[256, 512, 1024],
|
||||||
out_channels=[256, 512, 1024],
|
out_channels=[256, 512, 1024],
|
||||||
depthwise=False,
|
depthwise=False,
|
||||||
@@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module):
|
|||||||
block_name='BasicBlock',
|
block_name='BasicBlock',
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_features = in_features
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
Conv = DWConv if depthwise else BaseConv
|
Conv = DWConv if depthwise else BaseConv
|
||||||
|
|
||||||
@@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# backbone
|
# backbone
|
||||||
features = [out_features[f] for f in self.in_features]
|
[x2, x1, x0] = out_features
|
||||||
[x2, x1, x0] = features
|
|
||||||
|
|
||||||
# node x3
|
# node x3
|
||||||
x13 = self.bu_conv13(x1)
|
x13 = self.bu_conv13(x1)
|
||||||
|
|||||||
15
modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py
Normal file
15
modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
|
from modelscope.metainfo import Models
|
||||||
|
from modelscope.models.builder import MODELS
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
from .detector import SingleStageDetector
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module(
|
||||||
|
Tasks.image_object_detection, module_name=Models.tinynas_damoyolo)
|
||||||
|
class DamoYolo(SingleStageDetector):
|
||||||
|
|
||||||
|
def __init__(self, model_dir, *args, **kwargs):
|
||||||
|
self.config_name = 'damoyolo_s.py'
|
||||||
|
super(DamoYolo, self).__init__(model_dir, *args, **kwargs)
|
||||||
@@ -12,5 +12,5 @@ from .detector import SingleStageDetector
|
|||||||
class TinynasDetector(SingleStageDetector):
|
class TinynasDetector(SingleStageDetector):
|
||||||
|
|
||||||
def __init__(self, model_dir, *args, **kwargs):
|
def __init__(self, model_dir, *args, **kwargs):
|
||||||
|
self.config_name = 'airdet_s.py'
|
||||||
super(TinynasDetector, self).__init__(model_dir, *args, **kwargs)
|
super(TinynasDetector, self).__init__(model_dir, *args, **kwargs)
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline
|
|||||||
from modelscope.pipelines.builder import PIPELINES
|
from modelscope.pipelines.builder import PIPELINES
|
||||||
from modelscope.preprocessors import LoadImage
|
from modelscope.preprocessors import LoadImage
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
from modelscope.utils.cv.image_utils import \
|
||||||
|
show_image_object_detection_auto_result
|
||||||
from modelscope.utils.logger import get_logger
|
from modelscope.utils.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
@@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline):
|
|||||||
|
|
||||||
bboxes, scores, labels = self.model.postprocess(inputs['data'])
|
bboxes, scores, labels = self.model.postprocess(inputs['data'])
|
||||||
if bboxes is None:
|
if bboxes is None:
|
||||||
return None
|
outputs = {
|
||||||
outputs = {
|
OutputKeys.SCORES: [],
|
||||||
OutputKeys.SCORES: scores,
|
OutputKeys.LABELS: [],
|
||||||
OutputKeys.LABELS: labels,
|
OutputKeys.BOXES: []
|
||||||
OutputKeys.BOXES: bboxes
|
}
|
||||||
}
|
else:
|
||||||
|
outputs = {
|
||||||
|
OutputKeys.SCORES: scores,
|
||||||
|
OutputKeys.LABELS: labels,
|
||||||
|
OutputKeys.BOXES: bboxes
|
||||||
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def show_result(self, img_path, result, save_path=None):
|
||||||
|
show_image_object_detection_auto_result(img_path, result, save_path)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
@@ -35,3 +36,10 @@ def get_default_cache_dir():
|
|||||||
"""
|
"""
|
||||||
default_cache_dir = Path.home().joinpath('.cache', 'modelscope')
|
default_cache_dir = Path.home().joinpath('.cache', 'modelscope')
|
||||||
return default_cache_dir
|
return default_cache_dir
|
||||||
|
|
||||||
|
|
||||||
|
def read_file(path):
|
||||||
|
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
text = f.read()
|
||||||
|
return text
|
||||||
|
|||||||
@@ -4,22 +4,45 @@ import unittest
|
|||||||
|
|
||||||
from modelscope.pipelines import pipeline
|
from modelscope.pipelines import pipeline
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||||
from modelscope.utils.test_utils import test_level
|
from modelscope.utils.test_utils import test_level
|
||||||
|
|
||||||
|
|
||||||
class TinynasObjectDetectionTest(unittest.TestCase):
|
class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.task = Tasks.image_object_detection
|
||||||
|
self.model_id = 'damo/cv_tinynas_object-detection_damoyolo'
|
||||||
|
|
||||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||||
def test_run(self):
|
def test_run_airdet(self):
|
||||||
tinynas_object_detection = pipeline(
|
tinynas_object_detection = pipeline(
|
||||||
Tasks.image_object_detection, model='damo/cv_tinynas_detection')
|
Tasks.image_object_detection, model='damo/cv_tinynas_detection')
|
||||||
result = tinynas_object_detection(
|
result = tinynas_object_detection(
|
||||||
'data/test/images/image_detection.jpg')
|
'data/test/images/image_detection.jpg')
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
|
@unittest.skip('will be enabled after damoyolo officially released')
|
||||||
|
def test_run_damoyolo(self):
|
||||||
|
tinynas_object_detection = pipeline(
|
||||||
|
Tasks.image_object_detection,
|
||||||
|
model='damo/cv_tinynas_object-detection_damoyolo')
|
||||||
|
result = tinynas_object_detection(
|
||||||
|
'data/test/images/image_detection.jpg')
|
||||||
|
print(result)
|
||||||
|
|
||||||
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
|
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
|
||||||
def test_demo_compatibility(self):
|
def test_demo_compatibility(self):
|
||||||
self.test_demo()
|
self.compatibility_check()
|
||||||
|
|
||||||
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||||
|
def test_image_object_detection_auto_pipeline(self):
|
||||||
|
test_image = 'data/test/images/image_detection.jpg'
|
||||||
|
tinynas_object_detection = pipeline(
|
||||||
|
Tasks.image_object_detection, model='damo/cv_tinynas_detection')
|
||||||
|
result = tinynas_object_detection(test_image)
|
||||||
|
tinynas_object_detection.show_result(test_image, result,
|
||||||
|
'demo_ret.jpg')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user