From 8411645524360ef3a8135a4d2029492efe0135cd Mon Sep 17 00:00:00 2001 From: "zhongning.hzn" Date: Fri, 24 Feb 2023 14:23:57 +0800 Subject: [PATCH] =?UTF-8?q?bad=5Fimage=5Fdetecting=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=9C=A8=E6=95=B0=E6=8D=AE=E9=9B=86=E4=B8=8A?= =?UTF-8?q?validation=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11761935 --- modelscope/metrics/builder.py | 1 + .../cv/bad_image_detecting/bad_image_detecting.py | 11 ++++++----- .../bad_image_detecting_dataset.py | 15 +++++++++------ 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 76278288..0357fa25 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -70,6 +70,7 @@ task_default_metrics = { [Metrics.image_quality_assessment_degradation_metric], Tasks.image_quality_assessment_mos: [Metrics.image_quality_assessment_mos_metric], + Tasks.bad_image_detecting: [Metrics.accuracy], } diff --git a/modelscope/models/cv/bad_image_detecting/bad_image_detecting.py b/modelscope/models/cv/bad_image_detecting/bad_image_detecting.py index f8cb866c..f173f479 100644 --- a/modelscope/models/cv/bad_image_detecting/bad_image_detecting.py +++ b/modelscope/models/cv/bad_image_detecting/bad_image_detecting.py @@ -10,6 +10,7 @@ from modelscope.metainfo import Models from modelscope.models.base import Tensor from modelscope.models.base.base_torch_model import TorchModel from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger @@ -52,15 +53,14 @@ class BadImageDetecting(TorchModel): return {'output': ret} - def _evaluate_postprocess(self, input: Tensor, - target: Tensor) -> Dict[str, list]: + def _evaluate_postprocess(self, input: Tensor) -> Dict[str, list]: torch.cuda.empty_cache() with torch.no_grad(): preds = self.model(input) _, pred_ = torch.max(preds, dim=1) del input torch.cuda.empty_cache() - return {'pred': pred_, 'target': target} + return {OutputKeys.LABEL: pred_} def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Union[list, Tensor]]: @@ -74,7 +74,8 @@ class BadImageDetecting(TorchModel): """ if self.training: return self._train_forward(**inputs) - elif 'target' in inputs: - return self._evaluate_postprocess(**inputs) + elif OutputKeys.LABEL in inputs: + infeat = inputs['input'] + return self._evaluate_postprocess(infeat) else: return self._inference_forward(**inputs) diff --git a/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py b/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py index 91ef5d13..f3cd9a2f 100644 --- a/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py +++ b/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py @@ -7,8 +7,9 @@ from modelscope.metainfo import Models from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS from modelscope.msdatasets.task_datasets.torch_base_dataset import \ TorchTaskDataset +from modelscope.outputs import OutputKeys from modelscope.preprocessors import LoadImage -from modelscope.preprocessors.cv.bad_image_preprocessor import \ +from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \ BadImageDetectingPreprocessor from modelscope.utils.constant import Tasks @@ -31,9 +32,11 @@ class BadImageDetectingDataset(TorchTaskDataset): # Load input video paths. item_dict = self.dataset[index] - iterm_label = item_dict['label'] - - img = LoadImage.convert_to_ndarray(input) + iterm_label = item_dict['category'] + # print(input) + img = LoadImage.convert_to_ndarray(item_dict['image:FILE']) img = self.preprocessor(img) - - return {'input': img['input'], 'target': iterm_label} + return { + 'input': img['input'].squeeze(0), + OutputKeys.LABEL: iterm_label + }