Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-16 16:27:45 +01:00
[to #42322933]add cv-faceDetection and cv-faceRecognition
1. support FaceDetectionPipeline inference
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9470723
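
For context, a minimal usage sketch of the new detector through the standard ModelScope pipeline API (not part of this commit; the task name and hosted model id below are assumptions inferred from the 'resnet-face-detection-scrfd10gkps' registration added in metainfo):

# minimal sketch, assuming the usual ModelScope pipeline entry point and that
# a hosted model id matching the new registration exists
from modelscope.pipelines import pipeline

face_detection = pipeline(
    task='face-detection',                                # assumed task name
    model='damo/cv_resnet_facedetection_scrfd10gkps')     # assumed model id

result = face_detection('data/test/images/face_detection.png')
print(result)  # expected: scores, bounding boxes and 5-point keypoints per face
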
1  .gitignore  vendored
@@ -121,6 +121,7 @@ source.sh
 tensorboard.sh
 .DS_Store
 replace.sh
+result.png
 
 # Pytorch
 *.pth

3  data/test/images/face_detection.png  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa3963d1c54e6d3d46e9a59872a99ed955d4050092f5cfe5f591e03d740b7042
size 653006

3  data/test/images/face_recognition_1.png  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:48e541daeb2692907efef47018e41abb5ae6bcd88eb5ff58290d7fe5dc8b2a13
size 462584

3  data/test/images/face_recognition_2.png  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e9565b43d9f65361b9bad6553b327c2c6f02fd063a4c8dc0f461e88ea461989d
size 357166

@@ -10,6 +10,7 @@ class Models(object):
     Model name should only contain model info but not task info.
     """
     # vision models
+    scrfd = 'scrfd'
     classification_model = 'ClassificationModel'
     nafnet = 'nafnet'
     csrnet = 'csrnet'
@@ -67,6 +68,7 @@ class Pipelines(object):
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    face_detection = 'resnet-face-detection-scrfd10gkps'
     live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
     daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
@@ -76,6 +78,7 @@ class Pipelines(object):
     image_super_resolution = 'rrdb-image-super-resolution'
     face_image_generation = 'gan-face-image-generation'
     style_transfer = 'AAMS-style-transfer'
+    face_recognition = 'ir101-face-recognition-cfglint'
     image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
     image2image_translation = 'image-to-image-translation'
     live_category = 'live-category'

@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from . import (action_recognition, animal_recognition, cartoon,
-               cmdssl_video_embedding, face_generation, image_classification,
-               image_color_enhance, image_colorization, image_denoise,
-               image_instance_segmentation, super_resolution, virual_tryon)
+               cmdssl_video_embedding, face_detection, face_generation,
+               image_classification, image_color_enhance, image_colorization,
+               image_denoise, image_instance_segmentation,
+               image_to_image_translation, super_resolution, virual_tryon)

0  modelscope/models/cv/face_detection/__init__.py  Normal file

5  modelscope/models/cv/face_detection/mmdet_patch/__init__.py  Executable file
@@ -0,0 +1,5 @@
"""
mmdet_patch is based on
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet,
all duplicate functions from official mmdetection are removed.
"""

@@ -0,0 +1,3 @@
from .transforms import bbox2result, distance2kps, kps2distance

__all__ = ['bbox2result', 'distance2kps', 'kps2distance']

86  modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py  Executable file
@@ -0,0 +1,86 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/bbox/transforms.py
"""
import numpy as np
import torch


def bbox2result(bboxes, labels, num_classes, kps=None):
    """Convert detection results to a list of numpy arrays.

    Args:
        bboxes (torch.Tensor | np.ndarray): shape (n, 5)
        labels (torch.Tensor | np.ndarray): shape (n, )
        num_classes (int): class number, including background class

    Returns:
        list(ndarray): bbox results of each class
    """
    bbox_len = 5 if kps is None else 5 + 10  # if has kps, add 10 kps into bbox
    if bboxes.shape[0] == 0:
        return [
            np.zeros((0, bbox_len), dtype=np.float32)
            for i in range(num_classes)
        ]
    else:
        if isinstance(bboxes, torch.Tensor):
            bboxes = bboxes.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
        if kps is None:
            return [bboxes[labels == i, :] for i in range(num_classes)]
        else:  # with kps
            if isinstance(kps, torch.Tensor):
                kps = kps.detach().cpu().numpy()
            return [
                np.hstack([bboxes[labels == i, :], kps[labels == i, :]])
                for i in range(num_classes)
            ]


def distance2kps(points, distance, max_shape=None):
    """Decode distance predictions to keypoints.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom).
        max_shape (tuple): Shape of the image.

    Returns:
        Tensor: Decoded kps.
    """
    preds = []
    for i in range(0, distance.shape[1], 2):
        px = points[:, i % 2] + distance[:, i]
        py = points[:, i % 2 + 1] + distance[:, i + 1]
        if max_shape is not None:
            px = px.clamp(min=0, max=max_shape[1])
            py = py.clamp(min=0, max=max_shape[0])
        preds.append(px)
        preds.append(py)
    return torch.stack(preds, -1)


def kps2distance(points, kps, max_dis=None, eps=0.1):
    """Encode keypoints as distances from the given points.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        kps (Tensor): Shape (n, K), "xyxy" format
        max_dis (float): Upper bound of the distance.
        eps (float): a small value to ensure target < max_dis, instead <=

    Returns:
        Tensor: Decoded distances.
    """

    preds = []
    for i in range(0, kps.shape[1], 2):
        px = kps[:, i] - points[:, i % 2]
        py = kps[:, i + 1] - points[:, i % 2 + 1]
        if max_dis is not None:
            px = px.clamp(min=0, max=max_dis - eps)
            py = py.clamp(min=0, max=max_dis - eps)
        preds.append(px)
        preds.append(py)
    return torch.stack(preds, -1)

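As a quick check of the two keypoint codecs above, a minimal sketch (not part of the commit) that encodes five landmarks per anchor point with kps2distance and decodes them back with distance2kps; the values are made up:

# illustrative round trip; exact because no clamping is requested
import torch

from modelscope.models.cv.face_detection.mmdet_patch.core.bbox import (
    distance2kps, kps2distance)

points = torch.tensor([[10., 20.], [30., 40.]])   # anchor centres, shape (n, 2)
kps = torch.tensor([[12., 18., 9., 25., 11., 21., 8., 19., 13., 22.],
                    [33., 44., 28., 41., 31., 39., 29., 42., 34., 45.]])

dist = kps2distance(points, kps)      # per-keypoint offsets from each point
decoded = distance2kps(points, dist)  # back to absolute x/y coordinates
assert torch.allclose(decoded, kps)
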
@@ -0,0 +1,3 @@
from .bbox_nms import multiclass_nms

__all__ = ['multiclass_nms']

@@ -0,0 +1,85 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/post_processing/bbox_nms.py
"""
import torch


def multiclass_nms(multi_bboxes,
                   multi_scores,
                   score_thr,
                   nms_cfg,
                   max_num=-1,
                   score_factors=None,
                   return_inds=False,
                   multi_kps=None):
    """NMS for multi-class bboxes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
        multi_scores (Tensor): shape (n, #class), where the last column
            contains scores of the background class, but this will be ignored.
        score_thr (float): bbox threshold, bboxes with scores lower than it
            will not be considered.
        nms_thr (float): NMS IoU threshold
        max_num (int, optional): if there are more than max_num bboxes after
            NMS, only top max_num will be kept. Default to -1.
        score_factors (Tensor, optional): The factors multiplied to scores
            before applying NMS. Default to None.
        return_inds (bool, optional): Whether return the indices of kept
            bboxes. Default to False.

    Returns:
        tuple: (bboxes, labels, indices (optional)), tensors of shape (k, 5),
            (k), and (k). Labels are 0-based.
    """
    num_classes = multi_scores.size(1) - 1
    # exclude background category
    kps = None
    if multi_bboxes.shape[1] > 4:
        bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
        if multi_kps is not None:
            kps = multi_kps.view(multi_scores.size(0), -1, 10)
    else:
        bboxes = multi_bboxes[:, None].expand(
            multi_scores.size(0), num_classes, 4)
        if multi_kps is not None:
            kps = multi_kps[:, None].expand(
                multi_scores.size(0), num_classes, 10)

    scores = multi_scores[:, :-1]
    if score_factors is not None:
        scores = scores * score_factors[:, None]

    labels = torch.arange(num_classes, dtype=torch.long)
    labels = labels.view(1, -1).expand_as(scores)

    bboxes = bboxes.reshape(-1, 4)
    if kps is not None:
        kps = kps.reshape(-1, 10)
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)

    # remove low scoring boxes
    valid_mask = scores > score_thr
    inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
    bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
    if kps is not None:
        kps = kps[inds]
    if inds.numel() == 0:
        if torch.onnx.is_in_onnx_export():
            raise RuntimeError('[ONNX Error] Can not record NMS '
                               'as it has not been executed this time')
        return bboxes, labels, kps

    # TODO: add size check before feed into batched_nms
    from mmcv.ops.nms import batched_nms
    dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)

    if max_num > 0:
        dets = dets[:max_num]
        keep = keep[:max_num]

    if return_inds:
        return dets, labels[keep], kps[keep], keep
    else:
        return dets, labels[keep], kps[keep]

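A minimal call of the keypoint-aware NMS above (not part of the commit). It needs mmcv for batched_nms; the import path is assumed from the mmdet_patch layout and all tensors are made up:

import torch

from modelscope.models.cv.face_detection.mmdet_patch.core.post_processing import \
    multiclass_nms

n, num_classes = 4, 1                    # single foreground class, as in SCRFD
xy = torch.rand(n, 2) * 80
wh = torch.rand(n, 2) * 20 + 1
multi_bboxes = torch.cat([xy, xy + wh], dim=1)       # (n, 4) in x1, y1, x2, y2
multi_scores = torch.rand(n, num_classes + 1)        # last column = background
multi_kps = torch.rand(n, num_classes * 10) * 100    # 5 keypoints per box

dets, labels, kps = multiclass_nms(
    multi_bboxes, multi_scores, score_thr=0.05,
    nms_cfg=dict(type='nms', iou_threshold=0.45),
    max_num=200, multi_kps=multi_kps)
print(dets.shape, labels.shape, kps.shape)   # roughly (k, 5), (k,), (k, 10)
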
@@ -0,0 +1,3 @@
from .retinaface import RetinaFaceDataset

__all__ = ['RetinaFaceDataset']

@@ -0,0 +1,3 @@
from .transforms import RandomSquareCrop

__all__ = ['RandomSquareCrop']

188  modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py  Executable file
@@ -0,0 +1,188 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py
"""
import numpy as np
from mmdet.datasets.builder import PIPELINES
from numpy import random


@PIPELINES.register_module()
class RandomSquareCrop(object):
    """Random crop the image & bboxes, the cropped patches have minimum IoU
    requirement with original image & bboxes, the IoU threshold is randomly
    selected from min_ious.

    Args:
        min_ious (tuple): minimum IoU threshold for all intersections with
            bounding boxes
        min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
            where a >= min_crop_size).

    Note:
        The keys for bboxes, labels and masks should be paired. That is, \
        `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \
        `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
    """

    def __init__(self,
                 crop_ratio_range=None,
                 crop_choice=None,
                 bbox_clip_border=True):

        self.crop_ratio_range = crop_ratio_range
        self.crop_choice = crop_choice
        self.bbox_clip_border = bbox_clip_border

        assert (self.crop_ratio_range is None) ^ (self.crop_choice is None)
        if self.crop_ratio_range is not None:
            self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range

        self.bbox2label = {
            'gt_bboxes': 'gt_labels',
            'gt_bboxes_ignore': 'gt_labels_ignore'
        }
        self.bbox2mask = {
            'gt_bboxes': 'gt_masks',
            'gt_bboxes_ignore': 'gt_masks_ignore'
        }

    def __call__(self, results):
        """Call function to crop images and bounding boxes with minimum IoU
        constraint.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images and bounding boxes cropped, \
                'img_shape' key is updated.
        """

        if 'img_fields' in results:
            assert results['img_fields'] == ['img'], \
                'Only single img_fields is allowed'
        img = results['img']
        assert 'bbox_fields' in results
        assert 'gt_bboxes' in results
        boxes = results['gt_bboxes']
        h, w, c = img.shape
        scale_retry = 0
        if self.crop_ratio_range is not None:
            max_scale = self.crop_ratio_max
        else:
            max_scale = np.amax(self.crop_choice)
        while True:
            scale_retry += 1

            if scale_retry == 1 or max_scale > 1.0:
                if self.crop_ratio_range is not None:
                    scale = np.random.uniform(self.crop_ratio_min,
                                              self.crop_ratio_max)
                elif self.crop_choice is not None:
                    scale = np.random.choice(self.crop_choice)
            else:
                scale = scale * 1.2

            for i in range(250):
                short_side = min(w, h)
                cw = int(scale * short_side)
                ch = cw

                # TODO +1
                if w == cw:
                    left = 0
                elif w > cw:
                    left = random.randint(0, w - cw)
                else:
                    left = random.randint(w - cw, 0)
                if h == ch:
                    top = 0
                elif h > ch:
                    top = random.randint(0, h - ch)
                else:
                    top = random.randint(h - ch, 0)

                patch = np.array(
                    (int(left), int(top), int(left + cw), int(top + ch)),
                    dtype=np.int)

                # center of boxes should inside the crop img
                # only adjust boxes and instance masks when the gt is not empty
                # adjust boxes
                def is_center_of_bboxes_in_patch(boxes, patch):
                    # TODO >=
                    center = (boxes[:, :2] + boxes[:, 2:]) / 2
                    mask = \
                        ((center[:, 0] > patch[0])
                         * (center[:, 1] > patch[1])
                         * (center[:, 0] < patch[2])
                         * (center[:, 1] < patch[3]))
                    return mask

                mask = is_center_of_bboxes_in_patch(boxes, patch)
                if not mask.any():
                    continue
                for key in results.get('bbox_fields', []):
                    boxes = results[key].copy()
                    mask = is_center_of_bboxes_in_patch(boxes, patch)
                    boxes = boxes[mask]
                    if self.bbox_clip_border:
                        boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
                        boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
                    boxes -= np.tile(patch[:2], 2)

                    results[key] = boxes
                    # labels
                    label_key = self.bbox2label.get(key)
                    if label_key in results:
                        results[label_key] = results[label_key][mask]

                    # keypoints field
                    if key == 'gt_bboxes':
                        for kps_key in results.get('keypoints_fields', []):
                            keypointss = results[kps_key].copy()
                            keypointss = keypointss[mask, :, :]
                            if self.bbox_clip_border:
                                keypointss[:, :, :
                                           2] = keypointss[:, :, :2].clip(
                                               max=patch[2:])
                                keypointss[:, :, :
                                           2] = keypointss[:, :, :2].clip(
                                               min=patch[:2])
                            keypointss[:, :, 0] -= patch[0]
                            keypointss[:, :, 1] -= patch[1]
                            results[kps_key] = keypointss

                    # mask fields
                    mask_key = self.bbox2mask.get(key)
                    if mask_key in results:
                        results[mask_key] = results[mask_key][mask.nonzero()
                                                              [0]].crop(patch)

                # adjust the img no matter whether the gt is empty before crop
                rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128
                patch_from = patch.copy()
                patch_from[0] = max(0, patch_from[0])
                patch_from[1] = max(0, patch_from[1])
                patch_from[2] = min(img.shape[1], patch_from[2])
                patch_from[3] = min(img.shape[0], patch_from[3])
                patch_to = patch.copy()
                patch_to[0] = max(0, patch_to[0] * -1)
                patch_to[1] = max(0, patch_to[1] * -1)
                patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0])
                patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1])
                rimg[patch_to[1]:patch_to[3],
                     patch_to[0]:patch_to[2], :] = img[
                         patch_from[1]:patch_from[3],
                         patch_from[0]:patch_from[2], :]
                img = rimg
                results['img'] = img
                results['img_shape'] = img.shape

                return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(min_ious={self.min_iou}, '
        repr_str += f'crop_size={self.crop_size})'
        return repr_str

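Because RandomSquareCrop registers itself in mmdet's PIPELINES registry, it is normally referenced from a dataset pipeline config by type name. A sketch of such an entry (not part of the commit; the surrounding steps and the crop_choice values are illustrative):

# illustrative mmdet-style train pipeline fragment
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    # exactly one of crop_ratio_range / crop_choice may be set (xor-checked in __init__)
    dict(type='RandomSquareCrop', crop_choice=[0.3, 0.45, 0.6, 0.8, 1.0]),
    dict(type='Resize', img_scale=(640, 640), keep_ratio=False),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
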
151  modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py  Executable file
@@ -0,0 +1,151 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/retinaface.py
"""
import numpy as np
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset


@DATASETS.register_module()
class RetinaFaceDataset(CustomDataset):

    CLASSES = ('FG', )

    def __init__(self, min_size=None, **kwargs):
        self.NK = 5
        self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
        self.min_size = min_size
        self.gt_path = kwargs.get('gt_path')
        super(RetinaFaceDataset, self).__init__(**kwargs)

    def _parse_ann_line(self, line):
        values = [float(x) for x in line.strip().split()]
        bbox = np.array(values[0:4], dtype=np.float32)
        kps = np.zeros((self.NK, 3), dtype=np.float32)
        ignore = False
        if self.min_size is not None:
            assert not self.test_mode
            w = bbox[2] - bbox[0]
            h = bbox[3] - bbox[1]
            if w < self.min_size or h < self.min_size:
                ignore = True
        if len(values) > 4:
            if len(values) > 5:
                kps = np.array(
                    values[4:19], dtype=np.float32).reshape((self.NK, 3))
                for li in range(kps.shape[0]):
                    if (kps[li, :] == -1).all():
                        kps[li][2] = 0.0  # weight = 0, ignore
                    else:
                        assert kps[li][2] >= 0
                        kps[li][2] = 1.0  # weight
            else:  # len(values)==5
                if not ignore:
                    ignore = (values[4] == 1)
        else:
            assert self.test_mode

        return dict(bbox=bbox, kps=kps, ignore=ignore, cat='FG')

    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Args:
            ann_file (str): Path of annotation file.
            20220711@tyx: ann_file given as a list of img paths is also supported

        Returns:
            list[dict]: Annotation info from COCO api.
        """
        if isinstance(ann_file, list):
            data_infos = []
            for line in ann_file:
                name = line
                objs = [0, 0, 0, 0]
                data_infos.append(
                    dict(filename=name, width=0, height=0, objs=objs))
        else:
            name = None
            bbox_map = {}
            for line in open(ann_file, 'r'):
                line = line.strip()
                if line.startswith('#'):
                    value = line[1:].strip().split()
                    name = value[0]
                    width = int(value[1])
                    height = int(value[2])

                    bbox_map[name] = dict(width=width, height=height, objs=[])
                    continue
                assert name is not None
                assert name in bbox_map
                bbox_map[name]['objs'].append(line)
            print('origin image size', len(bbox_map))
            data_infos = []
            for name in bbox_map:
                item = bbox_map[name]
                width = item['width']
                height = item['height']
                vals = item['objs']
                objs = []
                for line in vals:
                    data = self._parse_ann_line(line)
                    if data is None:
                        continue
                    objs.append(data)  # data is (bbox, kps, cat)
                if len(objs) == 0 and not self.test_mode:
                    continue
                data_infos.append(
                    dict(filename=name, width=width, height=height, objs=objs))
        return data_infos

    def get_ann_info(self, idx):
        """Get COCO annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        data_info = self.data_infos[idx]

        bboxes = []
        keypointss = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in data_info['objs']:
            label = self.cat2label[obj['cat']]
            bbox = obj['bbox']
            keypoints = obj['kps']
            ignore = obj['ignore']
            if ignore:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
                keypointss.append(keypoints)
        if not bboxes:
            bboxes = np.zeros((0, 4))
            labels = np.zeros((0, ))
            keypointss = np.zeros((0, self.NK, 3))
        else:
            # bboxes = np.array(bboxes, ndmin=2) - 1
            bboxes = np.array(bboxes, ndmin=2)
            labels = np.array(labels)
            keypointss = np.array(keypointss, ndmin=3)
        if not bboxes_ignore:
            bboxes_ignore = np.zeros((0, 4))
            labels_ignore = np.zeros((0, ))
        else:
            bboxes_ignore = np.array(bboxes_ignore, ndmin=2)
            labels_ignore = np.array(labels_ignore)
        ann = dict(
            bboxes=bboxes.astype(np.float32),
            labels=labels.astype(np.int64),
            keypointss=keypointss.astype(np.float32),
            bboxes_ignore=bboxes_ignore.astype(np.float32),
            labels_ignore=labels_ignore.astype(np.int64))
        return ann

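From the parsing logic above, the expected label file is the retinaface-style format: one '# <filename> <width> <height>' line per image, followed by one line per face carrying 'x1 y1 x2 y2' plus 15 landmark values (5 points, each x y flag; all -1 marks the landmarks as missing). A made-up sample (not part of the commit):

# illustrative annotation content written to a temporary file; all numbers are invented
sample_ann = (
    '# 0_Parade/0_Parade_marchingband_1_849.jpg 1024 768\n'
    '449 330 571 479 '
    '488.9 373.6 1.0 542.1 376.5 1.0 515.5 412.3 1.0 '
    '492.1 441.9 1.0 540.4 443.1 1.0\n'
    '# 0_Parade/0_Parade_Parade_0_904.jpg 1024 768\n'
    '361 98 623 436 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n')
with open('train_ann.txt', 'w') as f:
    f.write(sample_ann)
# RetinaFaceDataset(ann_file='train_ann.txt', ...) then creates one entry per
# '#' line, with the following face lines parsed by _parse_ann_line.
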
2  modelscope/models/cv/face_detection/mmdet_patch/models/__init__.py  Executable file
@@ -0,0 +1,2 @@
from .dense_heads import *  # noqa: F401,F403
from .detectors import *  # noqa: F401,F403

@@ -0,0 +1,3 @@
from .resnet import ResNetV1e

__all__ = ['ResNetV1e']

@@ -0,0 +1,412 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones/resnet.py
"""
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
                      constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet.models.builder import BACKBONES
from mmdet.models.utils import ResLayer
from mmdet.utils import get_root_logger
from torch.nn.modules.batchnorm import _BatchNorm


class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        stem_channels (int | None): Number of stem channels. If not specified,
            it will be the same as `base_channels`. Default: None.
        base_channels (int): Number of base channels of res layer. Default: 64.
        in_channels (int): Number of input image channels. Default: 3.
        num_stages (int): Resnet stages. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        plugins (list[dict]): List of plugins for stages, each dict contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - position (str, required): Position inside block to insert
              plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'.
            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from mmdet.models import ResNet
        >>> import torch
        >>> self = ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)
    """

    arch_settings = {
        0: (BasicBlock, (2, 2, 2, 2)),
        18: (BasicBlock, (2, 2, 2, 2)),
        19: (BasicBlock, (2, 4, 4, 1)),
        20: (BasicBlock, (2, 3, 2, 2)),
        22: (BasicBlock, (2, 4, 3, 1)),
        24: (BasicBlock, (2, 4, 4, 1)),
        26: (BasicBlock, (2, 4, 4, 2)),
        28: (BasicBlock, (2, 5, 4, 2)),
        29: (BasicBlock, (2, 6, 3, 2)),
        30: (BasicBlock, (2, 5, 5, 2)),
        32: (BasicBlock, (2, 6, 5, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        35: (BasicBlock, (3, 6, 4, 3)),
        38: (BasicBlock, (3, 8, 4, 3)),
        40: (BasicBlock, (3, 8, 5, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        56: (Bottleneck, (3, 8, 4, 3)),
        68: (Bottleneck, (3, 10, 6, 3)),
        74: (Bottleneck, (3, 12, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 stem_channels=None,
                 base_channels=64,
                 num_stages=4,
                 block_cfg=None,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 no_pool33=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 plugins=None,
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        self.depth = depth
        if stem_channels is None:
            stem_channels = base_channels
        self.stem_channels = stem_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.deep_stem = deep_stem
        self.avg_down = avg_down
        self.no_pool33 = no_pool33
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.plugins = plugins
        self.zero_init_residual = zero_init_residual
        if block_cfg is None:
            self.block, stage_blocks = self.arch_settings[depth]
        else:
            self.block = BasicBlock if block_cfg[
                'block'] == 'BasicBlock' else Bottleneck
            stage_blocks = block_cfg['stage_blocks']
            assert len(stage_blocks) >= num_stages
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = stem_channels

        self._make_stem_layer(in_channels, stem_channels)
        if block_cfg is not None and 'stage_planes' in block_cfg:
            stage_planes = block_cfg['stage_planes']
        else:
            stage_planes = [base_channels * 2**i for i in range(num_stages)]

        # print('resnet cfg:', stage_blocks, stage_planes)
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            if plugins is not None:
                stage_plugins = self.make_stage_plugins(plugins, i)
            else:
                stage_plugins = None
            planes = stage_planes[i]
            res_layer = self.make_res_layer(
                block=self.block,
                inplanes=self.inplanes,
                planes=planes,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                avg_down=self.avg_down,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                plugins=stage_plugins)
            self.inplanes = planes * self.block.expansion
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = self.block.expansion * base_channels * 2**(
            len(self.stage_blocks) - 1)

    def make_stage_plugins(self, plugins, stage_idx):
        """Make plugins for ResNet ``stage_idx`` th stage.

        Currently we support to insert ``context_block``,
        ``empirical_attention_block``, ``nonlocal_block`` into the backbone
        like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
        Bottleneck.

        An example of plugins format could be:

        Examples:
            >>> plugins=[
            ...     dict(cfg=dict(type='xxx', arg1='xxx'),
            ...          stages=(False, True, True, True),
            ...          position='after_conv2'),
            ...     dict(cfg=dict(type='yyy'),
            ...          stages=(True, True, True, True),
            ...          position='after_conv3'),
            ...     dict(cfg=dict(type='zzz', postfix='1'),
            ...          stages=(True, True, True, True),
            ...          position='after_conv3'),
            ...     dict(cfg=dict(type='zzz', postfix='2'),
            ...          stages=(True, True, True, True),
            ...          position='after_conv3')
            ... ]
            >>> self = ResNet(depth=18)
            >>> stage_plugins = self.make_stage_plugins(plugins, 0)
            >>> assert len(stage_plugins) == 3

        Suppose ``stage_idx=0``, the structure of blocks in the stage would be:

        .. code-block:: none

            conv1-> conv2->conv3->yyy->zzz1->zzz2

        Suppose 'stage_idx=1', the structure of blocks in the stage would be:

        .. code-block:: none

            conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2

        If stages is missing, the plugin would be applied to all stages.

        Args:
            plugins (list[dict]): List of plugins cfg to build. The postfix is
                required if multiple same type plugins are inserted.
            stage_idx (int): Index of stage to build

        Returns:
            list[dict]: Plugins for current stage
        """
        stage_plugins = []
        for plugin in plugins:
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            assert stages is None or len(stages) == self.num_stages
            # whether to insert plugin into current stage
            if stages is None or stages[stage_idx]:
                stage_plugins.append(plugin)

        return stage_plugins

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(**kwargs)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels, stem_channels):
        if self.deep_stem:
            self.stem = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    in_channels,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
                nn.ReLU(inplace=True),
                build_conv_layer(
                    self.conv_cfg,
                    stem_channels // 2,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
                nn.ReLU(inplace=True),
                build_conv_layer(
                    self.conv_cfg,
                    stem_channels // 2,
                    stem_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels)[1],
                nn.ReLU(inplace=True))
        else:
            self.conv1 = build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False)
            self.norm1_name, norm1 = build_norm_layer(
                self.norm_cfg, stem_channels, postfix=1)
            self.add_module(self.norm1_name, norm1)
            self.relu = nn.ReLU(inplace=True)
        if self.no_pool33:
            assert self.deep_stem
            self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            if self.deep_stem:
                self.stem.eval()
                for param in self.stem.parameters():
                    param.requires_grad = False
            else:
                self.norm1.eval()
                for m in [self.conv1, self.norm1]:
                    for param in m.parameters():
                        param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m.conv2, 'conv_offset'):
                        constant_init(m.conv2.conv_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keep normalization layer
        freezed."""
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()


@BACKBONES.register_module()
class ResNetV1e(ResNet):
    r"""ResNetV1d variant described in `Bag of Tricks
    <https://arxiv.org/pdf/1812.01187.pdf>`_.

    Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
    the input stem with three 3x3 convs. And in the downsampling block, a 2x2
    avg_pool with stride 2 is added before conv, whose stride is changed to 1.

    Compared with ResNetV1d, ResNetV1e change maxpooling from 3x3 to 2x2 pad=1
    """

    def __init__(self, **kwargs):
        super(ResNetV1e, self).__init__(
            deep_stem=True, avg_down=True, no_pool33=True, **kwargs)

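Beyond the depth-keyed arch_settings, the backbone can also be built from an explicit block_cfg, which is how the SCRFD variants describe their searched stage layouts. A sketch (not part of the commit; the import path is assumed from the mmdet_patch layout and the numbers are placeholders, not a real SCRFD config):

import torch

from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e

net = ResNetV1e(
    depth=0,                          # must be a valid key; block_cfg overrides its blocks
    block_cfg=dict(
        block='BasicBlock',
        stage_blocks=(2, 3, 2, 2),
        stage_planes=[16, 64, 64, 128]),
    base_channels=16,
    out_indices=(0, 1, 2, 3),
    norm_cfg=dict(type='BN', requires_grad=True))
net.eval()

outs = net(torch.randn(1, 3, 224, 224))
for o in outs:
    print(tuple(o.shape))   # four feature maps at strides 4, 8, 16, 32
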
@@ -0,0 +1,3 @@
from .scrfd_head import SCRFDHead

__all__ = ['SCRFDHead']

1068  modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py  Executable file
File diff suppressed because it is too large.

@@ -0,0 +1,3 @@
from .scrfd import SCRFD

__all__ = ['SCRFD']

109  modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py  Executable file
@@ -0,0 +1,109 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors/scrfd.py
"""
import torch
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.single_stage import SingleStageDetector

from ....mmdet_patch.core.bbox import bbox2result


@DETECTORS.register_module()
class SCRFD(SingleStageDetector):

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SCRFD, self).__init__(backbone, neck, bbox_head, train_cfg,
                                    test_cfg, pretrained)

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_keypointss=None,
                      gt_bboxes_ignore=None):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_keypointss,
                                              gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            imgs (list[torch.Tensor]): List of multiple images
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        if torch.onnx.is_in_onnx_export():
            print('single_stage.py in-onnx-export')
            print(outs.__class__)
            cls_score, bbox_pred, kps_pred = outs
            for c in cls_score:
                print(c.shape)
            for c in bbox_pred:
                print(c.shape)
            if self.bbox_head.use_kps:
                for c in kps_pred:
                    print(c.shape)
                return (cls_score, bbox_pred, kps_pred)
            else:
                return (cls_score, bbox_pred)
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)

        # return kps if use_kps
        if len(bbox_list[0]) == 2:
            bbox_results = [
                bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
                for det_bboxes, det_labels in bbox_list
            ]
        elif len(bbox_list[0]) == 3:
            bbox_results = [
                bbox2result(
                    det_bboxes,
                    det_labels,
                    self.bbox_head.num_classes,
                    kps=det_kps)
                for det_bboxes, det_labels, det_kps in bbox_list
            ]
        return bbox_results

    def feature_test(self, img):
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        return outs

0  modelscope/models/cv/face_recognition/__init__.py  Normal file

50  modelscope/models/cv/face_recognition/align_face.py  Normal file
@@ -0,0 +1,50 @@
import cv2
import numpy as np
from skimage import transform as trans


def align_face(image, size, lmks):
    dst_w = size[1]
    dst_h = size[0]
    # landmark calculation of dst images
    base_w = 96
    base_h = 112
    assert (dst_w >= base_w)
    assert (dst_h >= base_h)
    base_lmk = [
        30.2946, 51.6963, 65.5318, 51.5014, 48.0252, 71.7366, 33.5493, 92.3655,
        62.7299, 92.2041
    ]

    dst_lmk = np.array(base_lmk).reshape((5, 2)).astype(np.float32)
    if dst_w != base_w:
        slide = (dst_w - base_w) / 2
        dst_lmk[:, 0] += slide

    if dst_h != base_h:
        slide = (dst_h - base_h) / 2
        dst_lmk[:, 1] += slide

    src_lmk = lmks
    # using skimage method
    tform = trans.SimilarityTransform()
    tform.estimate(src_lmk, dst_lmk)
    t = tform.params[0:2, :]

    assert (image.shape[2] == 3)

    dst_image = cv2.warpAffine(image.copy(), t, (dst_w, dst_h))
    dst_pts = GetAffinePoints(src_lmk, t)
    return dst_image, dst_pts


def GetAffinePoints(pts_in, trans):
    pts_out = pts_in.copy()
    assert (pts_in.shape[1] == 2)

    for k in range(pts_in.shape[0]):
        pts_out[k, 0] = pts_in[k, 0] * trans[0, 0] + pts_in[k, 1] * trans[
            0, 1] + trans[0, 2]
        pts_out[k, 1] = pts_in[k, 0] * trans[1, 0] + pts_in[k, 1] * trans[
            1, 1] + trans[1, 2]
    return pts_out

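A usage sketch connecting detection and recognition (not part of the commit): the five landmarks predicted by the SCRFD detector are passed to align_face to produce a 112x112 crop for the embedding backbone. The landmark values below are made up.

import cv2
import numpy as np

from modelscope.models.cv.face_recognition.align_face import align_face

img = cv2.imread('data/test/images/face_recognition_1.png')   # BGR, HWC
# assumed order: left eye, right eye, nose, left mouth corner, right mouth corner
lmks = np.array([[285.0, 417.0], [386.0, 409.0], [337.0, 477.0],
                 [299.0, 540.0], [379.0, 533.0]], dtype=np.float32)

aligned, warped_pts = align_face(img, (112, 112), lmks)
print(aligned.shape)   # (112, 112, 3), ready for the recognition model
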
0  modelscope/models/cv/face_recognition/torchkit/__init__.py  Executable file

31  modelscope/models/cv/face_recognition/torchkit/backbone/__init__.py  Executable file
@@ -0,0 +1,31 @@
from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50,
                         IR_SE_101, IR_SE_152, IR_SE_200)
from .model_resnet import ResNet_50, ResNet_101, ResNet_152

_model_dict = {
    'ResNet_50': ResNet_50,
    'ResNet_101': ResNet_101,
    'ResNet_152': ResNet_152,
    'IR_18': IR_18,
    'IR_34': IR_34,
    'IR_50': IR_50,
    'IR_101': IR_101,
    'IR_152': IR_152,
    'IR_200': IR_200,
    'IR_SE_50': IR_SE_50,
    'IR_SE_101': IR_SE_101,
    'IR_SE_152': IR_SE_152,
    'IR_SE_200': IR_SE_200
}


def get_model(key):
    """ Get different backbone network by key,
        support ResNet50, ResNet_101, ResNet_152
        IR_18, IR_34, IR_50, IR_101, IR_152, IR_200,
        IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200.
    """
    if key in _model_dict.keys():
        return _model_dict[key]
    else:
        raise KeyError('not support model {}'.format(key))

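A sketch of looking up and instantiating the recognition backbone registered above (not part of the commit); the factory signature and the 512-d output are assumptions based on the model_irse definitions that follow:

import torch

from modelscope.models.cv.face_recognition.torchkit.backbone import get_model

backbone_cls = get_model('IR_101')   # raises KeyError for unknown keys
net = backbone_cls([112, 112])       # assumed: the factory takes the input size
net.eval()

with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))   # expected: a 512-d embedding per image
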
68  modelscope/models/cv/face_recognition/torchkit/backbone/common.py  Executable file
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Linear, Module, ReLU,
                      Sigmoid)


def initialize_weights(modules):
    """ Weight initialize, conv2d and linear is initialized with kaiming_normal
    """
    for m in modules:
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                m.bias.data.zero_()


class Flatten(Module):
    """ Flat tensor
    """

    def forward(self, input):
        return input.view(input.size(0), -1)


class SEModule(Module):
    """ SE block
    """

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(
            channels,
            channels // reduction,
            kernel_size=1,
            padding=0,
            bias=False)

        nn.init.xavier_uniform_(self.fc1.weight.data)

        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(
            channels // reduction,
            channels,
            kernel_size=1,
            padding=0,
            bias=False)

        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)

        return module_input * x

279  modelscope/models/cv/face_recognition/torchkit/backbone/model_irse.py  Executable file
@@ -0,0 +1,279 @@
# based on:
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_irse.py
from collections import namedtuple

from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
                      MaxPool2d, Module, PReLU, Sequential)

from .common import Flatten, SEModule, initialize_weights


class BasicBlockIR(Module):
    """ BasicBlock for IRNet
    """

    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)

        return res + shortcut


class BottleneckIR(Module):
    """ BasicBlock with bottleneck for IRNet
    """

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIR, self).__init__()
        reduction_channel = depth // 4
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(
                in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False),
            BatchNorm2d(reduction_channel), PReLU(reduction_channel),
            Conv2d(
                reduction_channel,
                reduction_channel, (3, 3), (1, 1),
                1,
                bias=False), BatchNorm2d(reduction_channel),
            PReLU(reduction_channel),
            Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)

        return res + shortcut


class BasicBlockIRSE(BasicBlockIR):

    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIRSE, self).__init__(in_channel, depth, stride)
        self.res_layer.add_module('se_block', SEModule(depth, 16))


class BottleneckIRSE(BottleneckIR):

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIRSE, self).__init__(in_channel, depth, stride)
        self.res_layer.add_module('se_block', SEModule(depth, 16))


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    '''A named tuple describing a ResNet block.'''


def get_block(in_channel, depth, num_units, stride=2):

    return [Bottleneck(in_channel, depth, stride)] +\
        [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 18:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=2),
            get_block(in_channel=64, depth=128, num_units=2),
            get_block(in_channel=128, depth=256, num_units=2),
            get_block(in_channel=256, depth=512, num_units=2)
        ]
    elif num_layers == 34:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=6),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=256, num_units=3),
            get_block(in_channel=256, depth=512, num_units=8),
            get_block(in_channel=512, depth=1024, num_units=36),
            get_block(in_channel=1024, depth=2048, num_units=3)
        ]
    elif num_layers == 200:
        blocks = [
            get_block(in_channel=64, depth=256, num_units=3),
            get_block(in_channel=256, depth=512, num_units=24),
            get_block(in_channel=512, depth=1024, num_units=36),
            get_block(in_channel=1024, depth=2048, num_units=3)
        ]

    return blocks


class Backbone(Module):

    def __init__(self, input_size, num_layers, mode='ir'):
        """ Args:
            input_size: input_size of backbone
            num_layers: num_layers of backbone
            mode: support ir or irse
        """
        super(Backbone, self).__init__()
        assert input_size[0] in [112, 224], \
            'input_size should be [112, 112] or [224, 224]'
        assert num_layers in [18, 34, 50, 100, 152, 200], \
            'num_layers should be 18, 34, 50, 100 or 152'
        assert mode in ['ir', 'ir_se'], \
            'mode should be ir or ir_se'
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
            PReLU(64))
        blocks = get_blocks(num_layers)
        if num_layers <= 100:
            if mode == 'ir':
                unit_module = BasicBlockIR
            elif mode == 'ir_se':
                unit_module = BasicBlockIRSE
            output_channel = 512
        else:
            if mode == 'ir':
                unit_module = BottleneckIR
            elif mode == 'ir_se':
                unit_module = BottleneckIRSE
            output_channel = 2048

        if input_size[0] == 112:
            self.output_layer = Sequential(
                BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
                Linear(output_channel * 7 * 7, 512),
                BatchNorm1d(512, affine=False))
        else:
            self.output_layer = Sequential(
                BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
                Linear(output_channel * 14 * 14, 512),
                BatchNorm1d(512, affine=False))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel, bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

        initialize_weights(self.modules())

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
|
||||||
|
x = self.output_layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def IR_18(input_size):
|
||||||
|
""" Constructs a ir-18 model.
|
||||||
|
"""
|
||||||
|
model = Backbone(input_size, 18, 'ir')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def IR_34(input_size):
|
||||||
|
""" Constructs a ir-34 model.
|
||||||
|
"""
|
||||||
|
model = Backbone(input_size, 34, 'ir')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def IR_50(input_size):
|
||||||
|
""" Constructs a ir-50 model.
|
||||||
|
"""
|
||||||
|
model = Backbone(input_size, 50, 'ir')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def IR_101(input_size):
|
||||||
|
""" Constructs a ir-101 model.
|
||||||
|
"""
|
||||||
|
model = Backbone(input_size, 100, 'ir')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def IR_152(input_size):
|
||||||
|
""" Constructs a ir-152 model.
|
||||||
|
"""
|
||||||
|
model = Backbone(input_size, 152, 'ir')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def IR_200(input_size):
|
||||||
|
""" Constructs a ir-200 model.
|
||||||
|
"""
|
||||||
|
model = Backbone(input_size, 200, 'ir')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def IR_SE_50(input_size):
|
||||||
|
""" Constructs a ir_se-50 model.
|
||||||
|
"""
|
||||||
|
model = Backbone(input_size, 50, 'ir_se')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def IR_SE_101(input_size):
|
||||||
|
""" Constructs a ir_se-101 model.
|
||||||
|
"""
|
||||||
|
model = Backbone(input_size, 100, 'ir_se')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def IR_SE_152(input_size):
|
||||||
|
""" Constructs a ir_se-152 model.
|
||||||
|
"""
|
||||||
|
model = Backbone(input_size, 152, 'ir_se')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def IR_SE_200(input_size):
|
||||||
|
""" Constructs a ir_se-200 model.
|
||||||
|
"""
|
||||||
|
model = Backbone(input_size, 200, 'ir_se')
|
||||||
|
|
||||||
|
return model
|
||||||
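For orientation, a minimal usage sketch of these factory functions, using the same `get_model('IR_101')([112, 112])` call that FaceRecognitionPipeline makes further below (assumes torch is available; weights are loaded separately by the pipeline, so the numbers here are meaningless and only the shape matters):

import torch

from modelscope.models.cv.face_recognition.torchkit.backbone import get_model

backbone = get_model('IR_101')([112, 112])   # ir-101 backbone with 512-d output head
backbone.eval()
dummy = torch.randn(1, 3, 112, 112)          # one aligned 112x112 face crop, NCHW
with torch.no_grad():
    emb = backbone(dummy)
print(emb.shape)                              # expected: torch.Size([1, 512])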
162  modelscope/models/cv/face_recognition/torchkit/backbone/model_resnet.py  Executable file
@@ -0,0 +1,162 @@
# based on:
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_resnet.py
import torch.nn as nn
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
                      MaxPool2d, Module, ReLU, Sequential)

from .common import initialize_weights


def conv3x3(in_planes, out_planes, stride=1):
    """ 3x3 convolution with padding
    """
    return Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """ 1x1 convolution
    """
    return Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = BatchNorm2d(planes * self.expansion)
        self.relu = ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(Module):
    """ ResNet backbone
    """

    def __init__(self, input_size, block, layers, zero_init_residual=True):
        """ Args:
            input_size: input_size of backbone
            block: block function
            layers: layers in each block
        """
        super(ResNet, self).__init__()
        assert input_size[0] in [112, 224],\
            'input_size should be [112, 112] or [224, 224]'
        self.inplanes = 64
        self.conv1 = Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = ReLU(inplace=True)
        self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.bn_o1 = BatchNorm2d(2048)
        self.dropout = Dropout()
        if input_size[0] == 112:
            self.fc = Linear(2048 * 4 * 4, 512)
        else:
            self.fc = Linear(2048 * 7 * 7, 512)
        self.bn_o2 = BatchNorm1d(512)

        # pass the module iterator, as in model_irse.py
        initialize_weights(self.modules())
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn_o1(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bn_o2(x)

        return x


def ResNet_50(input_size, **kwargs):
    """ Constructs a ResNet-50 model.
    """
    model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)

    return model


def ResNet_101(input_size, **kwargs):
    """ Constructs a ResNet-101 model.
    """
    model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs)

    return model


def ResNet_152(input_size, **kwargs):
    """ Constructs a ResNet-152 model.
    """
    model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs)

    return model
@@ -13,6 +13,7 @@ class OutputKeys(object):
     POSES = 'poses'
     CAPTION = 'caption'
     BOXES = 'boxes'
+    KEYPOINTS = 'keypoints'
     MASKS = 'masks'
     TEXT = 'text'
     POLYGONS = 'polygons'
@@ -55,6 +56,31 @@ TASK_OUTPUTS = {
     Tasks.object_detection:
     [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES],

+    # face detection result for single sample
+    # {
+    #     "scores": [0.9, 0.1, 0.05, 0.05]
+    #     "boxes": [
+    #         [x1, y1, x2, y2],
+    #         [x1, y1, x2, y2],
+    #         [x1, y1, x2, y2],
+    #         [x1, y1, x2, y2],
+    #     ],
+    #     "keypoints": [
+    #         [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5],
+    #         [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5],
+    #         [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5],
+    #         [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5],
+    #     ],
+    # }
+    Tasks.face_detection:
+    [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS],
+
+    # face recognition result for single sample
+    # {
+    #     "img_embedding": np.array with shape [1, D],
+    # }
+    Tasks.face_recognition: [OutputKeys.IMG_EMBEDDING],
+
     # instance segmentation result for single sample
     # {
     #     "scores": [0.9, 0.1, 0.05, 0.05],
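To make the new TASK_OUTPUTS entries concrete, a small consumption sketch (model id and test image are the ones used in the tests added below; the concrete score/box values are illustrative only):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_detection = pipeline(
    Tasks.face_detection, model='damo/cv_resnet_facedetection_scrfd10gkps')
result = face_detection('data/test/images/face_detection.png')
# one score, one 4-value box and one 10-value (5-point) keypoint list per detected face
for score, box, kps in zip(result[OutputKeys.SCORES], result[OutputKeys.BOXES],
                           result[OutputKeys.KEYPOINTS]):
    print(score, box, kps)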
@@ -255,7 +255,11 @@ class Pipeline(ABC):
         elif isinstance(data, InputFeatures):
             return data
         else:
-            raise ValueError(f'Unsupported data type {type(data)}')
+            import mmcv
+            if isinstance(data, mmcv.parallel.data_container.DataContainer):
+                return data
+            else:
+                raise ValueError(f'Unsupported data type {type(data)}')

     def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
         preprocess_params = kwargs.get('preprocess_params')
@@ -80,6 +80,10 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.text_to_image_synthesis:
     (Pipelines.text_to_image_synthesis,
      'damo/cv_imagen_text-to-image-synthesis_tiny'),
+    Tasks.face_detection: (Pipelines.face_detection,
+                           'damo/cv_resnet_facedetection_scrfd10gkps'),
+    Tasks.face_recognition: (Pipelines.face_recognition,
+                             'damo/cv_ir101_facerecognition_cfglint'),
     Tasks.video_multi_modal_embedding:
     (Pipelines.video_multi_modal_embedding,
      'damo/multi_modal_clip_vtretrival_msrvtt_53'),
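With these defaults registered, the model id can be omitted and the builder falls back to the entries above; a short sketch (the face recognition pipeline still needs a detection pipeline passed in, as its constructor below requires):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# resolves to 'damo/cv_resnet_facedetection_scrfd10gkps' via DEFAULT_MODEL_FOR_PIPELINE
face_detection = pipeline(Tasks.face_detection)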
@@ -5,44 +5,50 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .action_recognition_pipeline import ActionRecognitionPipeline
-    from .animal_recog_pipeline import AnimalRecogPipeline
-    from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
-    from .live_category_pipeline import LiveCategoryPipeline
-    from .image_classification_pipeline import GeneralImageClassificationPipeline
+    from .animal_recognition_pipeline import AnimalRecognitionPipeline
+    from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
+    from .face_detection_pipeline import FaceDetectionPipeline
+    from .face_recognition_pipeline import FaceRecognitionPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
     from .image_cartoon_pipeline import ImageCartoonPipeline
+    from .image_classification_pipeline import GeneralImageClassificationPipeline
     from .image_denoise_pipeline import ImageDenoisePipeline
     from .image_color_enhance_pipeline import ImageColorEnhancePipeline
     from .image_colorization_pipeline import ImageColorizationPipeline
     from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
-    from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
-    from .video_category_pipeline import VideoCategoryPipeline
     from .image_matting_pipeline import ImageMattingPipeline
     from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
+    from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
     from .style_transfer_pipeline import StyleTransferPipeline
+    from .live_category_pipeline import LiveCategoryPipeline
     from .ocr_detection_pipeline import OCRDetectionPipeline
+    from .video_category_pipeline import VideoCategoryPipeline
     from .virtual_tryon_pipeline import VirtualTryonPipeline
 else:
     _import_structure = {
         'action_recognition_pipeline': ['ActionRecognitionPipeline'],
-        'animal_recog_pipeline': ['AnimalRecogPipeline'],
-        'cmdssl_video_embedding_pipleline': ['CMDSSLVideoEmbeddingPipeline'],
+        'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
+        'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
+        'face_detection_pipeline': ['FaceDetectionPipeline'],
+        'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
+        'face_recognition_pipeline': ['FaceRecognitionPipeline'],
         'image_classification_pipeline':
         ['GeneralImageClassificationPipeline'],
-        'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'],
-        'virtual_tryon_pipeline': ['VirtualTryonPipeline'],
-        'image_colorization_pipeline': ['ImageColorizationPipeline'],
-        'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
-        'image_denoise_pipeline': ['ImageDenoisePipeline'],
-        'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
         'image_cartoon_pipeline': ['ImageCartoonPipeline'],
-        'image_matting_pipeline': ['ImageMattingPipeline'],
-        'style_transfer_pipeline': ['StyleTransferPipeline'],
-        'ocr_detection_pipeline': ['OCRDetectionPipeline'],
+        'image_denoise_pipeline': ['ImageDenoisePipeline'],
+        'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'],
+        'image_colorization_pipeline': ['ImageColorizationPipeline'],
         'image_instance_segmentation_pipeline':
         ['ImageInstanceSegmentationPipeline'],
-        'video_category_pipeline': ['VideoCategoryPipeline'],
+        'image_matting_pipeline': ['ImageMattingPipeline'],
+        'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
+        'image_to_image_translation_pipeline':
+        ['Image2ImageTranslationPipeline'],
         'live_category_pipeline': ['LiveCategoryPipeline'],
+        'ocr_detection_pipeline': ['OCRDetectionPipeline'],
+        'style_transfer_pipeline': ['StyleTransferPipeline'],
+        'video_category_pipeline': ['VideoCategoryPipeline'],
+        'virtual_tryon_pipeline': ['VirtualTryonPipeline'],
     }

     import sys
@@ -23,7 +23,7 @@ class ActionRecognitionPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
+        use `model` to create an action recognition pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
@@ -22,11 +22,11 @@ logger = get_logger()

 @PIPELINES.register_module(
     Tasks.image_classification, module_name=Pipelines.animal_recognation)
-class AnimalRecogPipeline(Pipeline):
+class AnimalRecognitionPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
+        use `model` to create an animal recognition pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
@@ -24,7 +24,7 @@ class CMDSSLVideoEmbeddingPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
+        use `model` to create a CMDSSL video embedding pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
105  modelscope/pipelines/cv/face_detection_pipeline.py  Normal file
@@ -0,0 +1,105 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_detection, module_name=Pipelines.face_detection)
class FaceDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a face detection pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        from mmcv import Config
        from mmcv.parallel import MMDataParallel
        from mmcv.runner import load_checkpoint
        from mmdet.models import build_detector
        from modelscope.models.cv.face_detection.mmdet_patch.datasets import RetinaFaceDataset
        from modelscope.models.cv.face_detection.mmdet_patch.datasets.pipelines import RandomSquareCrop
        from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e
        from modelscope.models.cv.face_detection.mmdet_patch.models.dense_heads import SCRFDHead
        from modelscope.models.cv.face_detection.mmdet_patch.models.detectors import SCRFD
        cfg = Config.fromfile(osp.join(model, 'mmcv_scrfd_10g_bnkps.py'))
        detector = build_detector(
            cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE)
        logger.info(f'loading model from {ckpt_path}')
        device = torch.device(
            f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
        load_checkpoint(detector, ckpt_path, map_location=device)
        detector = MMDataParallel(detector, device_ids=[0])
        detector.eval()
        self.detector = detector
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img.astype(np.float32)
        pre_pipeline = [
            dict(
                type='MultiScaleFlipAug',
                img_scale=(640, 640),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip', flip_ratio=0.0),
                    dict(
                        type='Normalize',
                        mean=[127.5, 127.5, 127.5],
                        std=[128.0, 128.0, 128.0],
                        to_rgb=False),
                    dict(type='Pad', size=(640, 640), pad_val=0),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]
        from mmdet.datasets.pipelines import Compose
        pipeline = Compose(pre_pipeline)
        result = {}
        result['filename'] = ''
        result['ori_filename'] = ''
        result['img'] = img
        result['img_shape'] = img.shape
        result['ori_shape'] = img.shape
        result['img_fields'] = ['img']
        result = pipeline(result)
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

        result = self.detector(
            return_loss=False,
            rescale=True,
            img=[input['img'][0].unsqueeze(0)],
            img_metas=[[dict(input['img_metas'][0].data)]])
        assert result is not None
        result = result[0][0]
        bboxes = result[:, :4].tolist()
        kpss = result[:, 5:].tolist()
        scores = result[:, 4].tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: bboxes,
            OutputKeys.KEYPOINTS: kpss
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -24,7 +24,7 @@ class FaceImageGenerationPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` to create a kws pipeline for prediction
+        use `model` to create a face image generation pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
130  modelscope/pipelines/cv/face_recognition_pipeline.py  Normal file
@@ -0,0 +1,130 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_recognition.align_face import align_face
from modelscope.models.cv.face_recognition.torchkit.backbone import get_model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_recognition, module_name=Pipelines.face_recognition)
class FaceRecognitionPipeline(Pipeline):

    def __init__(self, model: str, face_detection: Pipeline, **kwargs):
        """
        use `model` to create a face recognition pipeline for prediction
        Args:
            model: model id on modelscope hub.
            face_detection: pipeline for face detection and face alignment before recognition
        """

        # face recognition model
        super().__init__(model=model, **kwargs)
        device = torch.device(
            f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
        self.device = device
        face_model = get_model('IR_101')([112, 112])
        face_model.load_state_dict(
            torch.load(
                osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE),
                map_location=device))
        face_model = face_model.to(device)
        face_model.eval()
        self.face_model = face_model
        logger.info('face recognition model loaded!')
        # face detection pipeline
        self.face_detection = face_detection

    def _choose_face(self,
                     det_result,
                     min_face=10,
                     top_face=1,
                     center_face=False):
        '''
        choose face with maximum area
        Args:
            det_result: output of face detection pipeline
            min_face: minimum size of valid face w/h
            top_face: take faces with top max areas
            center_face: choose the most centered face from multiple faces, only valid if top_face > 1
        '''
        bboxes = np.array(det_result[OutputKeys.BOXES])
        landmarks = np.array(det_result[OutputKeys.KEYPOINTS])
        # scores = np.array(det_result[OutputKeys.SCORES])
        if bboxes.shape[0] == 0:
            logger.info('No face detected!')
            return None
        # face idx with enough size
        face_idx = []
        for i in range(bboxes.shape[0]):
            box = bboxes[i]
            if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face:
                face_idx += [i]
        if len(face_idx) == 0:
            logger.info(
                f'Face size not enough, less than {min_face}x{min_face}!')
            return None
        bboxes = bboxes[face_idx]
        landmarks = landmarks[face_idx]
        # find max faces
        boxes = np.array(bboxes)
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sort_idx = np.argsort(area)[-top_face:]
        # find center face
        # note: this branch expects the source image (`img`), which is not passed in;
        # it is never reached with the default arguments (top_face=1, center_face=False)
        if top_face > 1 and center_face and bboxes.shape[0] > 1:
            img_center = [img.shape[1] // 2, img.shape[0] // 2]
            min_dist = float('inf')
            sel_idx = -1
            for _idx in sort_idx:
                box = boxes[_idx]
                dist = np.square(
                    np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square(
                        np.abs((box[1] + box[3]) / 2 - img_center[1]))
                if dist < min_dist:
                    min_dist = dist
                    sel_idx = _idx
            sort_idx = [sel_idx]
        main_idx = sort_idx[-1]
        return bboxes[main_idx], landmarks[main_idx]

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img[:, :, ::-1]
        det_result = self.face_detection(img.copy())
        rtn = self._choose_face(det_result)
        face_img = None
        if rtn is not None:
            _, face_lmks = rtn
            face_lmks = face_lmks.reshape(5, 2)
            align_img, _ = align_face(img, (112, 112), face_lmks)
            face_img = align_img[:, :, ::-1]  # to rgb
            face_img = np.transpose(face_img, axes=(2, 0, 1))
            face_img = (face_img / 255. - 0.5) / 0.5
            face_img = face_img.astype(np.float32)
        result = {}
        result['img'] = face_img
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        assert input['img'] is not None
        img = input['img'].unsqueeze(0)
        emb = self.face_model(img).detach().cpu().numpy()
        emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True))  # l2 norm
        return {OutputKeys.IMG_EMBEDDING: emb}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
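Since forward() l2-normalizes the embeddings, two faces can be compared with a plain dot product. A short sketch along the lines of the unit test added below; the 0.5 acceptance threshold is only an illustrative assumption, not a value shipped with the model:

import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_detection = pipeline(
    Tasks.face_detection, model='damo/cv_resnet_facedetection_scrfd10gkps')
face_recognition = pipeline(
    Tasks.face_recognition,
    face_detection=face_detection,
    model='damo/cv_ir101_facerecognition_cfglint')

emb1 = face_recognition('data/test/images/face_recognition_1.png')[OutputKeys.IMG_EMBEDDING]
emb2 = face_recognition('data/test/images/face_recognition_2.png')[OutputKeys.IMG_EMBEDDING]
cos_sim = float(np.dot(emb1[0], emb2[0]))   # embeddings are already l2-normalized
same_person = cos_sim > 0.5                 # assumed threshold, tune per use case
print(cos_sim, same_person)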
@@ -30,7 +30,7 @@ class ImageCartoonPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
+        use `model` to create an image cartoon pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
@@ -27,7 +27,7 @@ class ImageColorEnhancePipeline(Pipeline):
                  ImageColorEnhanceFinetunePreprocessor] = None,
                  **kwargs):
         """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
+        use `model` and `preprocessor` to create an image color enhance pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
@@ -25,7 +25,7 @@ class ImageColorizationPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` to create a kws pipeline for prediction
+        use `model` to create an image colorization pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
@@ -21,7 +21,7 @@ class ImageMattingPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
+        use `model` to create an image matting pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
@@ -23,7 +23,7 @@ class ImageSuperResolutionPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` to create a kws pipeline for prediction
+        use `model` to create an image super resolution pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
@@ -41,7 +41,7 @@ class OCRDetectionPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
+        use `model` to create an OCR detection pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
@@ -21,7 +21,7 @@ class StyleTransferPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
+        use `model` to create a style transfer pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
@@ -25,7 +25,7 @@ class VirtualTryonPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
         """
-        use `model` to create a kws pipeline for prediction
+        use `model` to create a virtual tryon pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
@@ -28,6 +28,8 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     action_recognition = 'action-recognition'
     video_embedding = 'video-embedding'
+    face_detection = 'face-detection'
+    face_recognition = 'face-recognition'
     image_color_enhance = 'image-color-enhance'
     virtual_tryon = 'virtual-tryon'
     image_colorization = 'image-colorization'
84  tests/pipelines/test_face_detection.py  Normal file
@@ -0,0 +1,84 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import tempfile
import unittest

import cv2
import numpy as np

from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class FaceDetectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'

    def show_result(self, img_path, bboxes, kpss, scores):
        bboxes = np.array(bboxes)
        kpss = np.array(kpss)
        scores = np.array(scores)
        img = cv2.imread(img_path)
        assert img is not None, f"Can't read img: {img_path}"
        for i in range(len(scores)):
            bbox = bboxes[i].astype(np.int32)
            kps = kpss[i].reshape(-1, 2).astype(np.int32)
            score = scores[i]
            x1, y1, x2, y2 = bbox
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
            for kp in kps:
                cv2.circle(img, tuple(kp), 1, (0, 0, 255), 1)
            cv2.putText(
                img,
                f'{score:.2f}', (x1, y2),
                1,
                1.0, (0, 255, 0),
                thickness=1,
                lineType=8)
        cv2.imwrite('result.png', img)
        print(
            f'Found {len(scores)} faces, output written to {osp.abspath("result.png")}'
        )

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_dataset(self):
        input_location = ['data/test/images/face_detection.png']
        # alternatively:
        # input_location = '/dir/to/images'

        dataset = MsDataset.load(input_location, target='image')
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
        # note that for dataset input, the inference output is a Generator that can be iterated.
        result = face_detection(dataset)
        result = next(result)
        self.show_result(input_location[0], result[OutputKeys.BOXES],
                         result[OutputKeys.KEYPOINTS],
                         result[OutputKeys.SCORES])

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
        img_path = 'data/test/images/face_detection.png'

        result = face_detection(img_path)
        self.show_result(img_path, result[OutputKeys.BOXES],
                         result[OutputKeys.KEYPOINTS],
                         result[OutputKeys.SCORES])

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        face_detection = pipeline(Tasks.face_detection)
        img_path = 'data/test/images/face_detection.png'
        result = face_detection(img_path)
        self.show_result(img_path, result[OutputKeys.BOXES],
                         result[OutputKeys.KEYPOINTS],
                         result[OutputKeys.SCORES])


if __name__ == '__main__':
    unittest.main()
42  tests/pipelines/test_face_recognition.py  Normal file
@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import tempfile
import unittest

import cv2
import numpy as np

from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class FaceRecognitionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.recog_model_id = 'damo/cv_ir101_facerecognition_cfglint'
        self.det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_face_compare(self):
        img1 = 'data/test/images/face_recognition_1.png'
        img2 = 'data/test/images/face_recognition_2.png'

        face_detection = pipeline(
            Tasks.face_detection, model=self.det_model_id)
        face_recognition = pipeline(
            Tasks.face_recognition,
            face_detection=face_detection,
            model=self.recog_model_id)
        # the embeddings are l2-normalized, so their dot product is the cosine similarity.
        emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
        emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
        sim = np.dot(emb1[0], emb2[0])
        print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')


if __name__ == '__main__':
    unittest.main()