bad_image_detecting模型支持在数据集上validation功能

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11761935
This commit is contained in:
zhongning.hzn
2023-02-24 14:23:57 +08:00
committed by wenmeng.zwm
parent 058f772d16
commit 8411645524
3 changed files with 16 additions and 11 deletions

View File

@@ -70,6 +70,7 @@ task_default_metrics = {
[Metrics.image_quality_assessment_degradation_metric],
Tasks.image_quality_assessment_mos:
[Metrics.image_quality_assessment_mos_metric],
Tasks.bad_image_detecting: [Metrics.accuracy],
}

View File

@@ -10,6 +10,7 @@ from modelscope.metainfo import Models
from modelscope.models.base import Tensor
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
@@ -52,15 +53,14 @@ class BadImageDetecting(TorchModel):
return {'output': ret}
def _evaluate_postprocess(self, input: Tensor,
target: Tensor) -> Dict[str, list]:
def _evaluate_postprocess(self, input: Tensor) -> Dict[str, list]:
torch.cuda.empty_cache()
with torch.no_grad():
preds = self.model(input)
_, pred_ = torch.max(preds, dim=1)
del input
torch.cuda.empty_cache()
return {'pred': pred_, 'target': target}
return {OutputKeys.LABEL: pred_}
def forward(self, inputs: Dict[str,
Tensor]) -> Dict[str, Union[list, Tensor]]:
@@ -74,7 +74,8 @@ class BadImageDetecting(TorchModel):
"""
if self.training:
return self._train_forward(**inputs)
elif 'target' in inputs:
return self._evaluate_postprocess(**inputs)
elif OutputKeys.LABEL in inputs:
infeat = inputs['input']
return self._evaluate_postprocess(infeat)
else:
return self._inference_forward(**inputs)

View File

@@ -7,8 +7,9 @@ from modelscope.metainfo import Models
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
TorchTaskDataset
from modelscope.outputs import OutputKeys
from modelscope.preprocessors import LoadImage
from modelscope.preprocessors.cv.bad_image_preprocessor import \
from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
BadImageDetectingPreprocessor
from modelscope.utils.constant import Tasks
@@ -31,9 +32,11 @@ class BadImageDetectingDataset(TorchTaskDataset):
# Load input video paths.
item_dict = self.dataset[index]
iterm_label = item_dict['label']
img = LoadImage.convert_to_ndarray(input)
iterm_label = item_dict['category']
# print(input)
img = LoadImage.convert_to_ndarray(item_dict['image:FILE'])
img = self.preprocessor(img)
return {'input': img['input'], 'target': iterm_label}
return {
'input': img['input'].squeeze(0),
OutputKeys.LABEL: iterm_label
}