substitute face detection model in skin_retouching_pipeline.py

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10909902
This commit is contained in:
ly261666
2022-12-02 15:41:08 +08:00
committed by wenmeng.zwm
parent c9a6b887a2
commit 4208d51e23

View File

@@ -15,11 +15,10 @@ from modelscope.models.cv.skin_retouching.detection_model.detection_unet_in impo
DetectionUNet
from modelscope.models.cv.skin_retouching.inpainting_model.inpainting_unet import \
RetouchingNet
from modelscope.models.cv.skin_retouching.retinaface.predict_single import \
Model
from modelscope.models.cv.skin_retouching.unet_deploy import UNet
from modelscope.models.cv.skin_retouching.utils import * # noqa F403
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
@@ -48,8 +47,6 @@ class SkinRetouchingPipeline(Pipeline):
device = create_device(self.device_name)
model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE)
detector_model_path = os.path.join(
self.model, 'retinaface_resnet50_2020-07-20_old_torch.pth')
local_model_path = os.path.join(self.model, 'joint_20210926.pth')
skin_model_path = os.path.join(self.model, ModelFile.TF_GRAPH_FILE)
@@ -58,10 +55,9 @@ class SkinRetouchingPipeline(Pipeline):
torch.load(model_path, map_location='cpu')['generator'])
self.generator.eval()
self.detector = Model(max_size=512, device=device)
state_dict = torch.load(detector_model_path, map_location='cpu')
self.detector.load_state_dict(state_dict)
self.detector.eval()
det_model_id = 'damo/cv_resnet50_face-detection_retinaface'
self.detector = pipeline(Tasks.face_detection, model=det_model_id)
self.detector.detector.to(device)
self.local_model_path = local_model_path
ckpt_dict_load = torch.load(self.local_model_path, map_location='cpu')
@@ -136,9 +132,18 @@ class SkinRetouchingPipeline(Pipeline):
(rgb_image.shape[0], rgb_image.shape[1], 3),
dtype=np.float32) * 0.5
results = self.detector.predict_jsons(
rgb_image
) # list, [{'bbox':, [x1, y1, x2, y2], 'score'...}, ...]
det_results = self.detector(rgb_image)
# list, [{'bbox':, [x1, y1, x2, y2], 'score'...}, ...]
results = []
for i in range(len(det_results['scores'])):
info_dict = {}
info_dict['bbox'] = np.array(det_results['boxes'][i]).astype(
np.int32).tolist()
info_dict['score'] = det_results['scores'][i]
info_dict['landmarks'] = np.array(
det_results['keypoints'][i]).astype(np.int32).reshape(
5, 2).tolist()
results.append(info_dict)
crop_bboxes = get_crop_bbox(results)