mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
[to #42322933] 新增RetinaFace人脸检测器
1. 新增人脸检测RetinaFace模型;
2. 完成Maas-cv CR标准自查
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9945188
This commit is contained in:
3
data/test/images/retina_face_detection.jpg
Normal file
3
data/test/images/retina_face_detection.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
|
||||
size 87228
|
||||
@@ -32,6 +32,7 @@ class Models(object):
|
||||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
|
||||
text_driven_segmentation = 'text-driven-segmentation'
|
||||
resnet50_bert = 'resnet50-bert'
|
||||
retinaface = 'retinaface'
|
||||
shop_segmentation = 'shop-segmentation'
|
||||
|
||||
# EasyCV models
|
||||
@@ -118,6 +119,7 @@ class Pipelines(object):
|
||||
salient_detection = 'u2net-salient-detection'
|
||||
image_classification = 'image-classification'
|
||||
face_detection = 'resnet-face-detection-scrfd10gkps'
|
||||
retina_face_detection = 'resnet50-face-detection-retinaface'
|
||||
live_category = 'live-category'
|
||||
general_image_classification = 'vit-base_image-classification_ImageNet-labels'
|
||||
daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
|
||||
|
||||
137
modelscope/models/cv/face_detection/retinaface/detection.py
Executable file
137
modelscope/models/cv/face_detection/retinaface/detection.py
Executable file
@@ -0,0 +1,137 @@
|
||||
# Adapted from the Pytorch_Retinaface implementation, available at https://github.com/biubug6/Pytorch_Retinaface
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Tensor, TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .models.retinaface import RetinaFace
|
||||
from .utils import PriorBox, decode, decode_landm, py_cpu_nms
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.face_detection, module_name=Models.retinaface)
class RetinaFaceDetection(TorchModel):
    """RetinaFace face detector: network, checkpoint loading and decoding.

    The constructor reads the model configuration that sits next to the
    checkpoint, builds the ``RetinaFace`` network, loads the weights and moves
    the network to ``device``. ``forward`` runs the network on one image and
    turns the raw predictions into boxes, scores and five-point landmarks.
    """

    def __init__(self, model_path, device='cuda'):
        """Build the network and load its weights.

        Args:
            model_path: Path of the checkpoint file. The configuration path is
                derived from it by substituting ``ModelFile.CONFIGURATION``
                for ``ModelFile.TORCH_MODEL_FILE``.
            device: Device the network and inputs are moved to.
        """
        super().__init__(model_path)
        # NOTE: autograd is no longer switched off for the whole process here;
        # `forward` disables it locally with `torch.no_grad()` instead, so
        # constructing a detector does not affect unrelated models.
        cudnn.benchmark = True
        self.model_path = model_path
        self.cfg = Config.from_file(
            model_path.replace(ModelFile.TORCH_MODEL_FILE,
                               ModelFile.CONFIGURATION))['models']
        self.net = RetinaFace(cfg=self.cfg)
        self.load_model()
        self.device = device
        self.net = self.net.to(self.device)

        # Kept for backward compatibility; `forward` subtracts the same
        # per-channel values on the NumPy side and does not read this tensor.
        self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device)

    def check_keys(self, pretrained_state_dict):
        """Verify that the checkpoint shares at least one key with the network.

        Args:
            pretrained_state_dict: Mapping of parameter names to tensors.

        Returns:
            ``True`` when at least one key matches.

        Raises:
            AssertionError: If no checkpoint key matches a network key.
        """
        ckpt_keys = set(pretrained_state_dict.keys())
        model_keys = set(self.net.state_dict().keys())
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped when Python runs with -O.
        if not model_keys & ckpt_keys:
            raise AssertionError('load NONE from pretrained checkpoint')
        return True

    def remove_prefix(self, state_dict, prefix):
        """Return a copy of ``state_dict`` with ``prefix`` stripped from keys.

        Keys that do not start with ``prefix`` are copied unchanged.
        """
        cut = len(prefix)
        return {(key[cut:] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()}

    def load_model(self, load_to_cpu=False):
        """Load weights from ``self.model_path`` into ``self.net``.

        Args:
            load_to_cpu: Unused; the checkpoint is always mapped to CPU first.
                Kept so existing callers keep working.
        """
        # NOTE(security): `torch.load` unpickles the file; only load
        # checkpoints from trusted sources.
        pretrained_dict = torch.load(
            self.model_path, map_location=torch.device('cpu'))
        if 'state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['state_dict']
        # Drop the 'module.' prefix from parameter names when present.
        pretrained_dict = self.remove_prefix(pretrained_dict, 'module.')
        self.check_keys(pretrained_dict)
        self.net.load_state_dict(pretrained_dict, strict=False)
        self.net.eval()

    def forward(self, input):
        """Detect faces in a single image.

        Args:
            input: Mapping whose ``'img'`` entry is a tensor that is indexed
                here as (height, width, channel).

        Returns:
            A tuple ``(dets, landms)`` of NumPy arrays. Each row of ``dets``
            is ``[x1, y1, x2, y2, score]`` and each row of ``landms`` holds
            ten landmark coordinates. Coordinates are expressed in the pixel
            scale of the image that was passed in.
        """
        confidence_threshold = 0.9
        nms_threshold = 0.4
        top_k = 5000
        keep_top_k = 750

        img_raw = input['img'].cpu().numpy()
        # Always copy: the in-place mean subtraction below must never write
        # through to the caller's tensor, whose memory `.numpy()` can share.
        img = np.array(img_raw, dtype=np.float32)

        im_height, im_width = img.shape[:2]
        ss = 1.0
        # Very large images are shrunk so that the longest side becomes 1000.
        longest_side = max(im_height, im_width)
        if longest_side > 1500:
            ss = 1000.0 / longest_side
            img = cv2.resize(img, (0, 0), fx=ss, fy=ss)
            im_height, im_width = img.shape[:2]

        # Per-channel mean subtraction (presumably BGR means -- confirm
        # against the preprocessing that produced the checkpoint).
        img -= (104, 117, 123)
        img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0)
        img = img.to(self.device)

        with torch.no_grad():
            loc, conf, landms = self.net(img)  # forward pass
            del img

            priors = PriorBox(
                self.cfg, image_size=(im_height, im_width)).forward()
            priors = priors.to(self.device)
            variance = self.cfg['variance']

            # Decoded values are relative to the image size; scale to pixels.
            box_scale = torch.Tensor([im_width, im_height] * 2).to(self.device)
            boxes = decode(loc.squeeze(0), priors, variance) * box_scale
            boxes = boxes.cpu().numpy()
            scores = conf.squeeze(0).cpu().numpy()[:, 1]

            landm_scale = torch.Tensor([im_width, im_height] * 5).to(
                self.device)
            landms = decode_landm(landms.squeeze(0), priors, variance)
            landms = (landms * landm_scale).cpu().numpy()

        # ignore low scores
        inds = np.where(scores > confidence_threshold)[0]
        boxes, landms, scores = boxes[inds], landms[inds], scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1][:top_k]
        boxes, landms, scores = boxes[order], landms[order], scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(
            np.float32, copy=False)
        keep = py_cpu_nms(dets, nms_threshold)

        # keep top-K after NMS (fancy indexing yields fresh arrays)
        dets = dets[keep, :][:keep_top_k, :]
        landms = landms[keep][:keep_top_k, :]

        # Map coordinates back to the original resolution. Only the four box
        # columns are rescaled: the fifth column is the confidence score and
        # must not be divided by the resize factor.
        dets[:, :4] /= ss
        return dets, landms / ss
|
||||
0
modelscope/models/cv/face_detection/retinaface/models/__init__.py
Executable file
0
modelscope/models/cv/face_detection/retinaface/models/__init__.py
Executable file
149
modelscope/models/cv/face_detection/retinaface/models/net.py
Executable file
149
modelscope/models/cv/face_detection/retinaface/models/net.py
Executable file
@@ -0,0 +1,149 @@
|
||||
# Adapted from the Pytorch_Retinaface implementation, available at https://github.com/biubug6/Pytorch_Retinaface
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
import torchvision.models._utils as _utils
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def conv_bn(inp, oup, stride=1, leaky=0):
    """Return a 3x3 convolution -> batch norm -> LeakyReLU stack.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the convolution.
        leaky: Negative slope of the LeakyReLU.
    """
    layers = [
        nn.Conv2d(
            inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def conv_bn_no_relu(inp, oup, stride):
    """Return a 3x3 convolution followed by batch norm, with no activation.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the convolution.
    """
    conv = nn.Conv2d(
        inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    return nn.Sequential(conv, norm)
|
||||
|
||||
|
||||
def conv_bn1X1(inp, oup, stride, leaky=0):
    """Return a 1x1 convolution -> batch norm -> LeakyReLU stack.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the convolution.
        leaky: Negative slope of the LeakyReLU.
    """
    pointwise = nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    activation = nn.LeakyReLU(negative_slope=leaky, inplace=True)
    return nn.Sequential(pointwise, norm, activation)
|
||||
|
||||
|
||||
def conv_dw(inp, oup, stride, leaky=0.1):
    """Return a depthwise 3x3 + pointwise 1x1 convolution stack.

    Each of the two convolutions is followed by batch norm and LeakyReLU.

    Args:
        inp: Number of input channels (also the depthwise group count).
        oup: Number of output channels of the pointwise convolution.
        stride: Stride of the depthwise convolution.
        leaky: Negative slope shared by both LeakyReLU layers.
    """
    depthwise = [
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)
|
||||
|
||||
|
||||
class SSH(nn.Module):
    """Context module with three parallel 3x3-convolution branches.

    One branch applies a single 3x3 block, one chains two and one chains
    three; the outputs are concatenated on the channel axis (half, quarter
    and quarter of ``out_channel``) and passed through ReLU.
    """

    def __init__(self, in_channel, out_channel):
        """Create the branches.

        Args:
            in_channel: Number of input channels.
            out_channel: Number of output channels; must be divisible by 4.
        """
        super().__init__()
        assert out_channel % 4 == 0
        leaky = 0.1 if out_channel <= 64 else 0
        half = out_channel // 2
        quarter = out_channel // 4

        # Attribute names (including the mixed 'X'/'x' casing) determine the
        # state-dict keys and therefore must stay exactly as they are.
        self.conv3X3 = conv_bn_no_relu(in_channel, half, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, quarter, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(quarter, quarter, stride=1)

        self.conv7X7_2 = conv_bn(quarter, quarter, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(quarter, quarter, stride=1)

    def forward(self, input):
        """Run the three branches on ``input`` and fuse their outputs."""
        branch3 = self.conv3X3(input)

        # The first 5x5 stage is shared by the 5x5 and 7x7 branches.
        shared = self.conv5X5_1(input)
        branch5 = self.conv5X5_2(shared)
        branch7 = self.conv7x7_3(self.conv7X7_2(shared))

        fused = torch.cat([branch3, branch5, branch7], dim=1)
        return F.relu(fused)
|
||||
|
||||
|
||||
class FPN(nn.Module):
    """Three-level top-down feature pyramid.

    Each input level is projected with a 1x1 block; coarser levels are then
    upsampled (nearest neighbour) to the next finer level's spatial size,
    added to it and smoothed with a 3x3 block.
    """

    def __init__(self, in_channels_list, out_channels):
        """Create the lateral and merge blocks.

        Args:
            in_channels_list: Channel counts of the three input levels, in
                the order the levels are passed to ``forward``.
            out_channels: Channel count shared by every output level.
        """
        super().__init__()
        leaky = 0.1 if out_channels <= 64 else 0

        self.output1 = conv_bn1X1(
            in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(
            in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(
            in_channels_list[2], out_channels, stride=1, leaky=leaky)

        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        """Build the pyramid from a mapping of three feature maps.

        Args:
            input: Mapping whose first three values are the feature maps,
                taken in iteration order.

        Returns:
            List ``[level1, level2, level3]`` of pyramid feature maps.
        """
        feats = list(input.values())

        level1 = self.output1(feats[0])
        level2 = self.output2(feats[1])
        level3 = self.output3(feats[2])

        # Top-down pathway: level 3 -> level 2.
        level2 = level2 + F.interpolate(
            level3, size=[level2.size(2), level2.size(3)], mode='nearest')
        level2 = self.merge2(level2)

        # Top-down pathway: (merged) level 2 -> level 1.
        level1 = level1 + F.interpolate(
            level2, size=[level1.size(2), level1.size(3)], mode='nearest')
        level1 = self.merge1(level1)

        return [level1, level2, level3]
|
||||
|
||||
|
||||
class MobileNetV1(nn.Module):
    """Small MobileNetV1-style classifier built from depthwise blocks.

    Three sequential stages of ``conv_dw`` blocks are followed by global
    average pooling and a 1000-way linear layer.
    """

    def __init__(self):
        super().__init__()
        # Each tuple is (input channels, output channels, stride).
        stage1_tail = [(8, 16, 1), (16, 32, 2), (32, 32, 1), (32, 64, 2),
                       (64, 64, 1)]
        stage2_spec = [(64, 128, 2)] + [(128, 128, 1)] * 5
        stage3_spec = [(128, 256, 2), (256, 256, 1)]

        # The stem is a regular 3x3 block; everything after it is depthwise.
        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),
            *[conv_dw(cin, cout, s) for cin, cout, s in stage1_tail])
        self.stage2 = nn.Sequential(
            *[conv_dw(cin, cout, s) for cin, cout, s in stage2_spec])
        self.stage3 = nn.Sequential(
            *[conv_dw(cin, cout, s) for cin, cout, s in stage3_spec])
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1000)

    def forward(self, x):
        """Return the 1000-dimensional output of the linear layer for ``x``."""
        for stage in (self.stage1, self.stage2, self.stage3):
            x = stage(x)
        pooled = self.avg(x).view(-1, 256)
        return self.fc(pooled)
|
||||
145
modelscope/models/cv/face_detection/retinaface/models/retinaface.py
Executable file
145
modelscope/models/cv/face_detection/retinaface/models/retinaface.py
Executable file
@@ -0,0 +1,145 @@
|
||||
# Adapted from the Pytorch_Retinaface implementation, available at https://github.com/biubug6/Pytorch_Retinaface
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
import torchvision.models._utils as _utils
|
||||
import torchvision.models.detection.backbone_utils as backbone_utils
|
||||
|
||||
from .net import FPN, SSH, MobileNetV1
|
||||
|
||||
|
||||
class ClassHead(nn.Module):
    """1x1 convolution head that emits two class scores per anchor."""

    def __init__(self, inchannels=512, num_anchors=3):
        """Create the head.

        Args:
            inchannels: Number of channels of the incoming feature map.
            num_anchors: Number of anchors predicted at each location.
        """
        super().__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(
            inchannels,
            self.num_anchors * 2,
            kernel_size=(1, 1),
            stride=1,
            padding=0)

    def forward(self, x):
        """Return scores reshaped to ``(batch, -1, 2)``."""
        # Move channels last so that each anchor's values are contiguous.
        scores = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        batch = scores.shape[0]
        return scores.view(batch, -1, 2)
|
||||
|
||||
|
||||
class BboxHead(nn.Module):
    """1x1 convolution head that emits four box values per anchor."""

    def __init__(self, inchannels=512, num_anchors=3):
        """Create the head.

        Args:
            inchannels: Number of channels of the incoming feature map.
            num_anchors: Number of anchors predicted at each location.
        """
        super().__init__()
        self.conv1x1 = nn.Conv2d(
            inchannels,
            num_anchors * 4,
            kernel_size=(1, 1),
            stride=1,
            padding=0)

    def forward(self, x):
        """Return box regressions reshaped to ``(batch, -1, 4)``."""
        # Move channels last so that each anchor's values are contiguous.
        regressions = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        batch = regressions.shape[0]
        return regressions.view(batch, -1, 4)
|
||||
|
||||
|
||||
class LandmarkHead(nn.Module):
    """1x1 convolution head that emits ten landmark values per anchor."""

    def __init__(self, inchannels=512, num_anchors=3):
        """Create the head.

        Args:
            inchannels: Number of channels of the incoming feature map.
            num_anchors: Number of anchors predicted at each location.
        """
        super().__init__()
        self.conv1x1 = nn.Conv2d(
            inchannels,
            num_anchors * 10,
            kernel_size=(1, 1),
            stride=1,
            padding=0)

    def forward(self, x):
        """Return landmark regressions reshaped to ``(batch, -1, 10)``."""
        # Move channels last so that each anchor's values are contiguous.
        regressions = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        batch = regressions.shape[0]
        return regressions.view(batch, -1, 10)
|
||||
|
||||
|
||||
class RetinaFace(nn.Module):
    """RetinaFace network: backbone, FPN, SSH modules and three heads."""

    def __init__(self, cfg=None):
        """Assemble the network.

        Args:
            cfg: Network settings. The keys read here are ``name``,
                ``pretrain``, ``return_layers``, ``in_channel`` and
                ``out_channel``.

        Raises:
            Exception: If ``cfg['name']`` is not ``'Resnet50'``.
        """
        super().__init__()
        if cfg['name'] != 'Resnet50':
            raise Exception('Invalid name')
        backbone = models.resnet50(pretrained=cfg['pretrain'])

        # Expose the intermediate backbone layers named in the config.
        self.body = _utils.IntermediateLayerGetter(backbone,
                                                   cfg['return_layers'])

        base_channels = cfg['in_channel']
        out_channels = cfg['out_channel']
        pyramid_inputs = [base_channels * factor for factor in (2, 4, 8)]

        self.fpn = FPN(pyramid_inputs, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        # One head of each kind per pyramid level.
        self.ClassHead = self._make_class_head(
            fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(
            fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(
            fpn_num=3, inchannels=cfg['out_channel'])

    def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Return ``fpn_num`` classification heads as a ``ModuleList``."""
        return nn.ModuleList(
            [ClassHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Return ``fpn_num`` box-regression heads as a ``ModuleList``."""
        return nn.ModuleList(
            [BboxHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Return ``fpn_num`` landmark heads as a ``ModuleList``."""
        return nn.ModuleList(
            [LandmarkHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def forward(self, inputs):
        """Run the network.

        Returns:
            Tuple ``(bbox_regressions, class_probabilities,
            landmark_regressions)``; the per-level head outputs are
            concatenated along dimension 1 and the class scores are passed
            through a softmax over the last dimension.
        """
        pyramid = self.fpn(self.body(inputs))

        # One SSH context module per pyramid level.
        features = [
            self.ssh1(pyramid[0]),
            self.ssh2(pyramid[1]),
            self.ssh3(pyramid[2]),
        ]

        bbox_regressions = torch.cat(
            [head(feat) for head, feat in zip(self.BboxHead, features)],
            dim=1)
        classifications = torch.cat(
            [head(feat) for head, feat in zip(self.ClassHead, features)],
            dim=1)
        ldm_regressions = torch.cat(
            [head(feat) for head, feat in zip(self.LandmarkHead, features)],
            dim=1)

        class_probs = F.softmax(classifications, dim=-1)
        return (bbox_regressions, class_probs, ldm_regressions)
|
||||
123
modelscope/models/cv/face_detection/retinaface/utils.py
Executable file
123
modelscope/models/cv/face_detection/retinaface/utils.py
Executable file
@@ -0,0 +1,123 @@
|
||||
# --------------------------------------------------------
|
||||
# Modified from https://github.com/biubug6/Pytorch_Retinaface
|
||||
# --------------------------------------------------------
|
||||
|
||||
from itertools import product as product
|
||||
from math import ceil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class PriorBox(object):
    """Generator of prior (anchor) boxes for a given image size."""

    def __init__(self, cfg, image_size=None, phase='train'):
        """Store the anchor settings and derive the feature-map sizes.

        Args:
            cfg: Mapping providing ``min_sizes``, ``steps`` and ``clip``.
            image_size: ``(height, width)`` of the network input.
            phase: Accepted but not used by this class.
        """
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        img_h, img_w = self.image_size[0], self.image_size[1]
        # One [rows, cols] entry per step.
        self.feature_maps = [[ceil(img_h / step),
                              ceil(img_w / step)] for step in self.steps]
        self.name = 's'

    def forward(self):
        """Return all priors as a tensor with rows ``[cx, cy, w, h]``.

        Values are relative to the image size; they are clamped to
        ``[0, 1]`` when ``clip`` is set.
        """
        img_h, img_w = self.image_size[0], self.image_size[1]
        anchors = []
        for level, fmap in enumerate(self.feature_maps):
            level_sizes = self.min_sizes[level]
            stride = self.steps[level]
            for row, col in product(range(fmap[0]), range(fmap[1])):
                # Centre of the cell, relative to the image size.
                cx = (col + 0.5) * stride / img_w
                cy = (row + 0.5) * stride / img_h
                for min_size in level_sizes:
                    anchors.extend(
                        (cx, cy, min_size / img_w, min_size / img_h))

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output
|
||||
|
||||
|
||||
def py_cpu_nms(dets, thresh):
    """Greedy non-maximum suppression on CPU.

    Args:
        dets: Array whose columns are ``x1, y1, x2, y2, score``.
        thresh: Boxes whose overlap ratio with an already kept box exceeds
            this value are discarded.

    Returns:
        List of kept row indices, ordered by decreasing score.
    """
    left, top, right, bottom, scores = (dets[:, col] for col in range(5))

    # Widths and heights use the inclusive "+ 1" pixel convention.
    areas = (right - left + 1) * (bottom - top + 1)
    remaining = scores.argsort()[::-1]

    keep = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        keep.append(best)

        # Intersection of the best box with every other remaining box.
        inter_left = np.maximum(left[best], left[rest])
        inter_top = np.maximum(top[best], top[rest])
        inter_right = np.minimum(right[best], right[rest])
        inter_bottom = np.minimum(bottom[best], bottom[rest])
        inter_w = np.maximum(0.0, inter_right - inter_left + 1)
        inter_h = np.maximum(0.0, inter_bottom - inter_top + 1)
        inter = inter_w * inter_h

        overlap = inter / (areas[best] + areas[rest] - inter)
        remaining = rest[overlap <= thresh]

    return keep
|
||||
|
||||
|
||||
# Adapted from https://github.com/Hakuyume/chainer-ssd
|
||||
def decode(loc, priors, variances):
    """Turn location predictions into corner-form boxes using the priors.

    This undoes the offset encoding applied for regression at train time.

    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions as ``[x1, y1, x2, y2]`` rows
    """
    prior_centers, prior_sizes = priors[:, :2], priors[:, 2:]

    centers = prior_centers + loc[:, :2] * variances[0] * prior_sizes
    sizes = prior_sizes * torch.exp(loc[:, 2:] * variances[1])

    # Convert (centre, size) to (top-left, bottom-right).
    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), 1)
|
||||
|
||||
|
||||
def decode_landm(pre, priors, variances):
    """Turn landmark predictions into coordinates using the priors.

    This undoes the offset encoding applied for regression at train time.

    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_priors,10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions
    """
    prior_centers, prior_sizes = priors[:, :2], priors[:, 2:]
    # Columns come in five (x, y) pairs; every pair is an offset from the
    # prior centre, scaled by the prior size.
    points = [
        prior_centers + pre[:, start:start + 2] * variances[0] * prior_sizes
        for start in range(0, 10, 2)
    ]
    return torch.cat(points, dim=1)
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import os.path as osp
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Generator, List, Mapping, Union
|
||||
|
||||
|
||||
55
modelscope/pipelines/cv/retina_face_detection_pipeline.py
Normal file
55
modelscope/pipelines/cv/retina_face_detection_pipeline.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import os.path as osp
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.face_detection.retinaface import detection
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.face_detection, module_name=Pipelines.retina_face_detection)
class RetinaFaceDetectionPipeline(Pipeline):
    """Face detection pipeline backed by ``RetinaFaceDetection``."""

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a face detection pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        # The checkpoint is looked up inside `model`, which is joined here
        # as a directory path.
        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {ckpt_path}')
        self.detector = detection.RetinaFaceDetection(
            model_path=ckpt_path, device=self.device)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Load ``input`` as a float32 array stored under the ``'img'`` key."""
        image = LoadImage.convert_to_ndarray(input)
        return {'img': image.astype(np.float32)}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the detector and split its output into plain Python lists.

        Returns:
            Dict with scores, boxes (the first four detection columns) and
            keypoints, keyed by the corresponding ``OutputKeys`` members.
        """
        result = self.detector(input)
        assert result is not None
        detections, landmarks = result[0], result[1]
        bboxes = detections[:, :4].tolist()
        scores = detections[:, 4].tolist()
        lms = landmarks.tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: bboxes,
            OutputKeys.KEYPOINTS: lms,
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``inputs`` unchanged; ``forward`` already formats them."""
        return inputs
|
||||
33
tests/pipelines/test_retina_face_detection.py
Normal file
33
tests/pipelines/test_retina_face_detection.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_result
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class RetinaFaceDetectionTest(unittest.TestCase):
    """Smoke test for the RetinaFace face detection pipeline."""

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet50_face-detection_retinaface'

    def show_result(self, img_path, detection_result):
        """Draw the detections on the image and save it as ``result.png``."""
        annotated = draw_face_detection_result(img_path, detection_result)
        cv2.imwrite('result.png', annotated)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        """Build the pipeline from the model id and run it on a test image."""
        img_path = 'data/test/images/retina_face_detection.jpg'
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)

        result = face_detection(img_path)
        self.show_result(img_path, result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user