diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 76278288..0357fa25 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -70,6 +70,7 @@ task_default_metrics = { [Metrics.image_quality_assessment_degradation_metric], Tasks.image_quality_assessment_mos: [Metrics.image_quality_assessment_mos_metric], + Tasks.bad_image_detecting: [Metrics.accuracy], } diff --git a/modelscope/models/cv/bad_image_detecting/bad_image_detecting.py b/modelscope/models/cv/bad_image_detecting/bad_image_detecting.py index f8cb866c..f173f479 100644 --- a/modelscope/models/cv/bad_image_detecting/bad_image_detecting.py +++ b/modelscope/models/cv/bad_image_detecting/bad_image_detecting.py @@ -10,6 +10,7 @@ from modelscope.metainfo import Models from modelscope.models.base import Tensor from modelscope.models.base.base_torch_model import TorchModel from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger @@ -52,15 +53,14 @@ class BadImageDetecting(TorchModel): return {'output': ret} - def _evaluate_postprocess(self, input: Tensor, - target: Tensor) -> Dict[str, list]: + def _evaluate_postprocess(self, input: Tensor) -> Dict[str, list]: torch.cuda.empty_cache() with torch.no_grad(): preds = self.model(input) _, pred_ = torch.max(preds, dim=1) del input torch.cuda.empty_cache() - return {'pred': pred_, 'target': target} + return {OutputKeys.LABEL: pred_} def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Union[list, Tensor]]: @@ -74,7 +74,8 @@ class BadImageDetecting(TorchModel): """ if self.training: return self._train_forward(**inputs) - elif 'target' in inputs: - return self._evaluate_postprocess(**inputs) + elif OutputKeys.LABEL in inputs: + infeat = inputs['input'] + return self._evaluate_postprocess(infeat) else: return self._inference_forward(**inputs) diff --git a/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py b/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py index 91ef5d13..f3cd9a2f 100644 --- a/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py +++ b/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py @@ -7,8 +7,9 @@ from modelscope.metainfo import Models from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS from modelscope.msdatasets.task_datasets.torch_base_dataset import \ TorchTaskDataset +from modelscope.outputs import OutputKeys from modelscope.preprocessors import LoadImage -from modelscope.preprocessors.cv.bad_image_preprocessor import \ +from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \ BadImageDetectingPreprocessor from modelscope.utils.constant import Tasks @@ -31,9 +32,11 @@ class BadImageDetectingDataset(TorchTaskDataset): # Load input video paths. item_dict = self.dataset[index] - iterm_label = item_dict['label'] - - img = LoadImage.convert_to_ndarray(input) + iterm_label = item_dict['category'] + # print(input) + img = LoadImage.convert_to_ndarray(item_dict['image:FILE']) img = self.preprocessor(img) - - return {'input': img['input'], 'target': iterm_label} + return { + 'input': img['input'].squeeze(0), + OutputKeys.LABEL: iterm_label + }