mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
66 lines
2.4 KiB
Python
66 lines
2.4 KiB
Python
# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch
|
|
|
|
from .utils import IntermediateLayerGetter
|
|
from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
|
|
from . import s2m_resnet
|
|
|
|
def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
    """Build a DeepLabV3 or DeepLabV3+ model on top of an ``s2m_resnet`` backbone.

    Args:
        name (str): head architecture, either ``'deeplabv3'`` or
            ``'deeplabv3plus'``.
        backbone_name (str): name of a constructor looked up in
            ``s2m_resnet.__dict__`` (e.g. ``'resnet50'``).
        num_classes (int): number of output classes for the segmentation head.
        output_stride (int): ``8`` passes
            ``replace_stride_with_dilation=[False, True, True]`` to the backbone
            and uses ASPP rates ``[12, 24, 36]``; any other value passes
            ``[False, False, True]`` with rates ``[6, 12, 18]`` (i.e. it is
            treated as the stride-16 configuration).
        pretrained_backbone (bool): forwarded to the backbone constructor as
            ``pretrained``.

    Returns:
        DeepLabV3: the backbone (wrapped in ``IntermediateLayerGetter``)
        combined with the chosen head.

    Raises:
        NotImplementedError: if ``name`` is not a supported head.
    """
    # Validate the head name before building the (possibly pretrained)
    # backbone. Previously an unknown name fell through both branches below
    # and crashed with UnboundLocalError on ``return_layers``.
    if name not in ('deeplabv3plus', 'deeplabv3'):
        raise NotImplementedError(f"Unsupported DeepLab variant: {name!r}")

    if output_stride == 8:
        replace_stride_with_dilation = [False, True, True]
        aspp_dilate = [12, 24, 36]
    else:
        replace_stride_with_dilation = [False, False, True]
        aspp_dilate = [6, 12, 18]

    backbone = s2m_resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation)

    # NOTE(review): these channel counts assume a Bottleneck ResNet
    # (e.g. ResNet-50) whose layer4 outputs 2048 and layer1 outputs 256
    # channels -- confirm before using other s2m_resnet constructors.
    inplanes = 2048
    low_level_planes = 256

    if name == 'deeplabv3plus':
        # V3+ decoder also consumes the high-resolution layer1 features.
        return_layers = {'layer4': 'out', 'layer1': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    else:  # name == 'deeplabv3' (validated above)
        return_layers = {'layer4': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    model = DeepLabV3(backbone, classifier)
    return model
|
|
|
|
def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
    """Dispatch to the builder that matches ``backbone``.

    Only ResNet backbones (names starting with ``'resnet'``) are supported;
    they are built by ``_segm_resnet``.

    Args:
        arch_type (str): head architecture, ``'deeplabv3'`` or ``'deeplabv3plus'``.
        backbone (str): backbone name, e.g. ``'resnet50'``.
        num_classes (int): number of output classes.
        output_stride (int): output stride for DeepLab (see ``_segm_resnet``).
        pretrained_backbone (bool): whether to load a pretrained backbone.

    Returns:
        The model returned by ``_segm_resnet``.

    Raises:
        NotImplementedError: if ``backbone`` is not a ResNet variant.
    """
    if not backbone.startswith('resnet'):
        # Name the offending backbone instead of raising a bare, message-less error.
        raise NotImplementedError(f"Unsupported backbone: {backbone!r}")
    return _segm_resnet(arch_type, backbone, num_classes,
                        output_stride=output_stride,
                        pretrained_backbone=pretrained_backbone)
|
|
|
|
|
|
# Deeplab v3
|
|
def deeplabv3_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False):
    """Build a DeepLabV3 segmentation network with a ResNet-50 backbone.

    Args:
        num_classes (int): number of output classes.
        output_stride (int): output stride for DeepLab.
        pretrained_backbone (bool): if True, load the pretrained backbone.

    Returns:
        The model produced by ``_load_model`` for the ``'deeplabv3'`` head.
    """
    head, trunk = 'deeplabv3', 'resnet50'
    return _load_model(
        head,
        trunk,
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
|
|
|
|
|
|
# Deeplab v3+
|
|
def deeplabv3plus_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False):
    """Constructs a DeepLabV3+ model with a ResNet-50 backbone.

    Unlike ``deeplabv3_resnet50``, the V3+ head also consumes the backbone's
    ``layer1`` features (returned as ``'low_level'``).

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.

    Returns:
        The model produced by ``_load_model`` for the ``'deeplabv3plus'`` head.
    """
    return _load_model('deeplabv3plus', 'resnet50', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
|
|
|