mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
substitute face detection model in skin_retouching_pipeline.py
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10909902
This commit is contained in:
@@ -15,11 +15,10 @@ from modelscope.models.cv.skin_retouching.detection_model.detection_unet_in impo
|
||||
DetectionUNet
|
||||
from modelscope.models.cv.skin_retouching.inpainting_model.inpainting_unet import \
|
||||
RetouchingNet
|
||||
from modelscope.models.cv.skin_retouching.retinaface.predict_single import \
|
||||
Model
|
||||
from modelscope.models.cv.skin_retouching.unet_deploy import UNet
|
||||
from modelscope.models.cv.skin_retouching.utils import * # noqa F403
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
@@ -48,8 +47,6 @@ class SkinRetouchingPipeline(Pipeline):
|
||||
|
||||
device = create_device(self.device_name)
|
||||
model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE)
|
||||
detector_model_path = os.path.join(
|
||||
self.model, 'retinaface_resnet50_2020-07-20_old_torch.pth')
|
||||
local_model_path = os.path.join(self.model, 'joint_20210926.pth')
|
||||
skin_model_path = os.path.join(self.model, ModelFile.TF_GRAPH_FILE)
|
||||
|
||||
@@ -58,10 +55,9 @@ class SkinRetouchingPipeline(Pipeline):
|
||||
torch.load(model_path, map_location='cpu')['generator'])
|
||||
self.generator.eval()
|
||||
|
||||
self.detector = Model(max_size=512, device=device)
|
||||
state_dict = torch.load(detector_model_path, map_location='cpu')
|
||||
self.detector.load_state_dict(state_dict)
|
||||
self.detector.eval()
|
||||
det_model_id = 'damo/cv_resnet50_face-detection_retinaface'
|
||||
self.detector = pipeline(Tasks.face_detection, model=det_model_id)
|
||||
self.detector.detector.to(device)
|
||||
|
||||
self.local_model_path = local_model_path
|
||||
ckpt_dict_load = torch.load(self.local_model_path, map_location='cpu')
|
||||
@@ -136,9 +132,18 @@ class SkinRetouchingPipeline(Pipeline):
|
||||
(rgb_image.shape[0], rgb_image.shape[1], 3),
|
||||
dtype=np.float32) * 0.5
|
||||
|
||||
results = self.detector.predict_jsons(
|
||||
rgb_image
|
||||
) # list, [{'bbox':, [x1, y1, x2, y2], 'score'...}, ...]
|
||||
det_results = self.detector(rgb_image)
|
||||
# list, [{'bbox':, [x1, y1, x2, y2], 'score'...}, ...]
|
||||
results = []
|
||||
for i in range(len(det_results['scores'])):
|
||||
info_dict = {}
|
||||
info_dict['bbox'] = np.array(det_results['boxes'][i]).astype(
|
||||
np.int32).tolist()
|
||||
info_dict['score'] = det_results['scores'][i]
|
||||
info_dict['landmarks'] = np.array(
|
||||
det_results['keypoints'][i]).astype(np.int32).reshape(
|
||||
5, 2).tolist()
|
||||
results.append(info_dict)
|
||||
|
||||
crop_bboxes = get_crop_bbox(results)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user